TensorFlow - tf.GradientTape의 원리

Michael Kim·2021년 7월 7일
1

0. 공부 배경

Tensorflow 2.0의 Eager모드에서는 다음과 같이 구성되어 다이나믹 연산이 이루어진다.
배우면서 생각해보니 python의 with ... as의 원리가 어떤 건지 생각해본 적이 없었다.
여기에서 with tf.GradientTape() as tape는 어떻게 돌아가는 것일까?

1. Python - with...as 구문

Python Docs에서는...

=> 간단하게 요약하자면 with를 만났을 때, enter()가 호출되고, enter()의 반환값이 as 뒤의 변수에 저장된다. 그리고 with 구문을 나가면서 exit()함수가 실행되고, 설정된 예외함수가 실행된다는 것이다.

2. tf.GradientTape

2-1. enter(), exit() method

이번에는 tensorflow의 Github를 뒤져보자 tf.GradientTape

역시나 enter와 exit함수가 있었다. with함수를 만났을 때 GradientTape의 인스턴스가 as 뒤의 Tape 변수에 담겨 객체로 생성되는 것처럼 보인다.
그리고 그 전에 push_tape 함수가 호출되는 것을 볼 수 있다.

간단하게 해석하자면, tape(auto grad graph)가 이미 기록되고(돌아가고)있으면 오류를 발생시키고, 중단된 tape가 있으면 이어서 실행, 그리고 없다면 tape.push_new_tape()를 호출하여 새로운 tape를 stack에 push한다. 그리고 push된 tape를 객체의 _ tape 변수에 저장되는 원리다. 그리고 기록 중이라는 flag를 True로 띄운다

with구문을 나가게 되면 exit()가 호출된다. exit함수는 아주 간단하다. tape가 기록 중에 있으면 pop_tape()를 호출한다.
그리고 pop_tape는...

tape가 기록 중이지 않을 때 오류를 발생시키고, 기록 중이었다면, 기록한 tape_를 tape의 pop_tape함수에 담아 실행시키면 stack에 있는 tape가 pop된다(pop된 tape는 GradientTape 인스턴스의 _ tape에 저장된 상태인 것으로 추측이 된다). 그리고나서 기록 flag를 False로 바꿔주면서 기록을 끝낸다.

2-2. tf.GradientTape.gradient()

마지막으로 tf.GradientTape.gradient(self, target, sources, output_gradients, unconnected_gradients) 함수로 grad를 반환받는다.

target: 미분을 실행할 Tensor or Variable 가 담긴 구조체를 넣는다.
sources: target의 미분 대상인 Tensor or Variable들을 넣는다.
output_gradients:
unconnected_gradients: sources에 대한 target의 미분 값이 0일 경우(target의 변수에 sources가 없거나, relu같은 활성화함수에 의해 0을 미분할 경우) 반환값을 어떤 값으로 대체할 지 정하도록 한다. NONE, ZERO를 넣을 수 있다. (계산과정에서 어떠한 차이가 있을 지는 모르겠다...)

요약하자면, 함수의 입력 변수들과 기록된 tape 정보들을 고려하여 imperative_grad 함수를 통해 기울기를 계산해낸다. 그리고 계산된 가중치의 기울기(grad)를 반환한다. 다시말해, gradient함수를 호출될 때 기울기가 계산되어 반환된다.

profile
정리하고 복습하고 일기도 쓰고

1개의 댓글

comment-user-thumbnail
2022년 5월 6일

설명잘보고갑니다 이해빡되네요
북마크목록에넣었습니다

답글 달기