참고 자료:
http://cs231n.stanford.edu/handouts/linear-backprop.pdf
순전파
Y=outputX=inputW=weightY=XW
Y, X, W는 모두 행렬이다. W가 학습되는 파라미터이다. X로부터 들어온 입력값으로 W를 행렬곱 하고 Y로 출력한다. Y는 다음 레이어의 입력으로 들어가고, 마지막에는 Y를 통해 loss값인 l을 구할 수 있다.
역전파
l이 구해지고 뒷 레이어부터 그래디언트가 차례대로 구해졌을 것으로 가정한다. 즉 ∂Y∂l를 알고 있다. 이 상황에서 해야할 일은 다음과 같다.
- ∂X∂l를 구해 앞 레이어로 전달.
- ∂W∂l를 구해 현재 레이어의 파라미터인 W를 업데이트
Y=⎣⎢⎢⎢⎢⎡y11y21⋮yn1y12y22⋮yn2⋯⋯⋱y1my2mynm⎦⎥⎥⎥⎥⎤X=⎣⎢⎢⎢⎢⎡x11x21⋮xn1x12x22⋮xn2⋯⋯⋱x1dx2dxnd⎦⎥⎥⎥⎥⎤W=⎣⎢⎢⎢⎢⎡w11w21⋮wd1w12w22⋮wd2⋯⋯⋱w1mw2mwdm⎦⎥⎥⎥⎥⎤
라고 하자. 그러면 구하고자 하는 것은
∂X∂l=⎣⎢⎢⎢⎢⎡∂x11∂l∂x12∂l⋮∂x1m∂l∂x21∂l∂x22∂l⋮∂x2m∂l⋯⋯⋱∂xn1∂l∂xn2∂l∂xnm∂l⎦⎥⎥⎥⎥⎤
이다. 항목 하나씩 구해보는 거로 생각해보자.
∂x11∂l=i=1∑nj=1∑m∂yij∂l∂x11∂yij
이다. 편미분의 정의에 의해 x11에 의해 Y가 변화하고, Y의 각각의 값의 변화량이 l에 영향을 미치므로 위와 같이 합으로 나타낸다. ∂yij∂l는 이미 주어진 ∂Y∂l로 알 수 있고 ∂x11∂yij는 Y=XW식에서 yij로 구할 수 있다.
y11=x11w11+x12w21+⋯+x1dwd1∂x11∂y11=w11y12=x11w12+x12w22+⋯+x1dwd2∂x11∂y12=w12y21=x21w11+x22w21+⋯+x2dwd1∂x11∂y21=0
이렇게 반복하다보면 ∂X∂l를 완전히 구할 수 있다.
행렬로 나타내어 식을 정리하면 아래와 같이 나타낼 수 있고 같은 방식으로 W에 대해서도 적용할 수 있다.
∂X∂l=∂Y∂lWT∂W∂l=XT∂Y∂l
Y=XW+b와 같이 bias가 있는 경우에도 같은 방식으로 구할 수 있다.
∂b∂l=∂Y∂l⎣⎢⎢⎢⎢⎡11⋮1⎦⎥⎥⎥⎥⎤
∂Y∂l를 행으로 더한 값이다.