BPTT(Backpropagation Through Time)
Dive into Deep Learning 교재의 BPTT부분을 발췌하여 학습하고자 한다.
BPTT는 교재 8단원 7절에 해당한다. 1~6절 앞부분을 미리 읽고 와도 도움이 될 듯하다. 이번 글에서는 RNN에서 시퀀스 모델에 대한 역전파 알고리즘의 디테일에 대해 알아보고 어떻게 수학이 활용되었는지 알아보겠다.
가장 먼저, 시퀀스 모델에서 Gradient가 어떻게 계산되는지 리뷰해보겠다. MLP의 역전파에서도 그랬듯, Gradient의 계산에는 Chain Rule을 사용한다. 이전 글에서 역전파에 대해 다뤘던 적이 있었으니 참고해도 좋다.
BPTT는 RNN의 역전파 방법들 중 하나로, 모델 변수와 파라미터의 의존성을 얻기 위해 한 번에 한 단계씩 RNN의 Computational Graph를 확장시켜야 한다. 그리고나서 ChainRule에 기반하여 Gradient를 계산하고 저장할 수 있도록 역전파를 진행하는 것이다.
Analysis of Gradients in RNNs
단순화된 RNN 모델을 이용해서 시작해보자. 이 모델은 Hidden State가 어떻게 업데이트되는지에 대한 디테일은 어느정도 무시하는 모델이다. 또, 아래 서술된 수식들은 scalars,vectors, matrices를 명시적으로 구별하는 것들은 아니다. 그런 것들은 분석에서 크게 중요하지 않으며 표기를 더욱 혼란스럽게만 할 뿐이다.
ht는 hidden state, xt는 input, ot는 output at time step t로 하자. 이와 같이, 은닉층과 출력층의 가중치로써 wh와 wo를 사용하겠다. 결과적으로 time step에 따른 Hidden state와 출력은 다음과 같이 나타낼 수 있다.
htot=f(xt,ht−1,wh),=g(ht,wo),
f,g는 각각 은닉층과 출력층의 변환을 의미한다.
이 때, 우리는 Desired label값 y에 대해서 모든 T time step의 목적함수 L을 다음과 같이 작성할 수 있다.
L(x1,…,xT,y1,…,yT,wh,wo)=T1∑t=1Tl(yt,ot)
역전파에서, 특히 L의 wh에 대한 gradients를 계산할 때는 조금 더 Trickier하다. 자세히 표현하자면, Chain Rule에 의해
∂wh∂L=T1t=1∑T∂wh∂l(yt,ot)=T1t=1∑T∂ot∂l(yt,ot)∂ht∂g(ht,wo)∂wh∂ht
로 표현되고, 첫 번째와 두 번째 factor들은 쉽게 계산 가능하다. 세 번째 factor인 ∂ht/∂wh는 wh와 ht의 효과를 순환적으로 계산하기에, 조금 까다롭다.
처음 정의에 의해 ht는 ht−1과 wh에 의존한다. 또한 ht−1도 wh에 의존하고 있기에 Chain Rule에 의해 다음과 같이 나타낼 수 있다.
∂wh∂ht=∂wh∂f(xt,ht−1,wh)+∂ht−1∂f(xt,ht−1,wh)∂wh∂ht−1
위의 Gradients는 세 개의 시퀀스를 가지고 있는데, {at},{bt},{ct} 다음을 만족한다. a0=0 and at=bt+ctat−1 for t=1,2,… for t≥1, 는 다음과 같이 나타낼 수 있다.
at=bt+i=1∑t−1(j=i+1∏tcj)bi
at,bt, and ct 다음과 같이 대체할 수 있다.
at=∂wh∂htbt=∂wh∂f(xt,ht−1,wh)ct=∂ht−1∂f(xt,ht−1,wh)
결과적으로 다음과 같은 결과값을 얻을 수 있다.
∂wh∂ht=∂wh∂f(xt,ht−1,wh)+∑i=1t−1(∏j=i+1t∂hj−1∂f(xj,hj−1,wh))∂wh∂f(xi,hi−1,wh)