BPTT

홍종현·2022년 4월 27일

Backpropagation Through Time

BPTT는 RNN에서 계산되는 back propagation이다. RNN의 구조는 sequential하기 때문에 이에 따라 발생하는 hidden state를 따라 역행하면서 전파되는 gradient의 계산 방법이다. 다음은 RNN의 기본 구조이다.

st=tahn(Uxt+Wst1)s_t = tahn(Ux_t + W_{s_{t-1}})
ytˉ=sotfmax(Vst)\bar{y_t} = sotfmax(Vs_t)

UU: input을 연결하는 가중치 행렬
WW: 현재 hidden state와 다음 hidden state를 연결하는 가중치 행렬
VV: output을 연결하는 가중치 행렬

이때 첫번째 셀의 loss값을 다음과 같이 정리할 수 있다.

E1=(d1ˉy1ˉ)2E_1 = (\bar{d_1} -\bar{y_1})^2

만약 Timestep = 3이라면 현재 model에서 update해야되는 weight matrix는 UU, WW, VV 3개가 존재합니다.

그렇다면 UU의 gradient를 구하기 위하여 E3E_3에 대하여 UU로 편미분한 결과를 chain rule을 통해 구할 수 있다.

E3U=E3y3ˉy3U\frac{\partial E_3}{\partial U} = \frac{\partial E_3}{\partial \bar{y_3}} \cdot \frac{\partial y_3}{\partial U}

다음 WW의 gradient를 구하면 다음과 같다.

E3W=E3y3ˉy3W\frac{\partial E_3}{\partial W} = \frac{\partial E_3}{\partial \bar{y_3}} \cdot \frac{\partial y_3}{\partial W}
E3W=E3y3ˉy3s3s3W\frac{\partial E_3}{\partial W} = \frac{\partial E_3}{\partial \bar{y_3}} \cdot \frac{\partial y_3}{\partial s_3} \cdot \frac{\partial s_3}{\partial W}

하지만 기존 backpropagation과는 다른 점이 각 timestep이 gradient에 영향을 주었기 때문에 s3W\frac{\partial s_3}{\partial W}를 상수 취급할 수 없다. 현재 timestep이 3인 출력 부분까지의 이전 timestep이 적용되기 때문에 t=2, t=1 까지 gradient를 전부 더해야한다. 그렇다면 다음과 같이 정리할 수 있다.

E3W=E3y3ˉy3s3s3W+E3y3ˉy3s3s3s2s2W+E3y3ˉy3s3s3s2s2s1s1W\frac{\partial E_3}{\partial W} = \frac{\partial E_3}{\partial \bar{y_3}} \cdot \frac{\partial y_3}{\partial s_3} \cdot \frac{\partial s_3}{\partial W} + \frac{\partial E_3}{\partial \bar{y_3}} \cdot \frac{\partial y_3}{\partial s_3} \cdot \frac{\partial s_3}{\partial s_2} \cdot \frac{\partial s_2}{\partial W} + \frac{\partial E_3}{\partial \bar{y_3}} \cdot \frac{\partial y_3}{\partial s_3} \cdot \frac{\partial s_3}{\partial s_2} \cdot \frac{\partial s_2}{\partial s_1} \cdot \frac{\partial s_1}{\partial W}

간단하게 정리하면 다음과 같은 식으로 정리할 수 있다.

E3W=i=03E3y3ˉy3s3s3sisiW\frac{\partial E_3}{\partial W} = \sum^3_{i=0}\frac{\partial E_3}{\partial \bar{y_3}} \cdot \frac{\partial y_3}{\partial s_3} \cdot \frac{\partial s_3}{\partial s_i} \cdot \frac{\partial s_i}{\partial W}

그리고 s3si\frac{\partial s_3}{\partial s_i}도 chain rule을 내포하고 있기 때문에 s3si=s3s2s2s1\frac{\partial s_3}{\partial s_i} = \frac{\partial s_3}{\partial s_2} \cdot \frac{\partial s_2}{\partial s_1}로 나타낼 수 있다. 위의 gradient를 다시 써보면 다음과 같이 나타낼 수 있다.

E3W=i=03E3y3ˉy3s3(j=i+13sjsj1)siW\frac{\partial E_3}{\partial W} = \sum^3_{i=0}\frac{\partial E_3}{\partial \bar{y_3}} \cdot \frac{\partial y_3}{\partial s_3} \bigg( \prod^3_{j=i+1}\frac{\partial s_{j}}{\partial s_{j-1}}\bigg) \cdot \frac{\partial s_i}{\partial W}

하지만 BPTT를 진행하다보면 처음부터 끝까지 모든 loss를 backpropagation해야하기 때문에 계샨량이 너무 많아지는데, 이것을 줄이고자 Truncated-Backpropagation Through Time(생략된-BPTT)를 많이 사용한다.

0개의 댓글