Linear Layer의 역전파 수식 알아보기

seokj·2023년 2월 11일
0

참고 자료:
http://cs231n.stanford.edu/handouts/linear-backprop.pdf


순전파

Y=outputX=inputW=weightY=XW\text{Y}=\text{output} \quad \text{X}=\text{input} \quad \text{W}=\text{weight} \\ \text{Y}=\text{XW}

Y, X, W\text{Y, X, W}는 모두 행렬이다. W\text{W}가 학습되는 파라미터이다. X\text{X}로부터 들어온 입력값으로 W\text{W}를 행렬곱 하고 Y\text{Y}로 출력한다. Y\text{Y}는 다음 레이어의 입력으로 들어가고, 마지막에는 Y\text{Y}를 통해 loss값인 ll을 구할 수 있다.

역전파

ll이 구해지고 뒷 레이어부터 그래디언트가 차례대로 구해졌을 것으로 가정한다. 즉 lY\frac{\partial l}{\partial \text{Y}}를 알고 있다. 이 상황에서 해야할 일은 다음과 같다.

  • lX\frac{\partial l}{\partial \text{X}}를 구해 앞 레이어로 전달.
  • lW\frac{\partial l}{\partial \text{W}}를 구해 현재 레이어의 파라미터인 W\text{W}를 업데이트
Y=[y11y12y1my21y22y2myn1yn2ynm]X=[x11x12x1dx21x22x2dxn1xn2xnd]W=[w11w12w1mw21w22w2mwd1wd2wdm]\text{Y}= \begin{bmatrix} y_{11} & y_{12} & \cdots & y_{1m} \\ y_{21} & y_{22} & \cdots & y_{2m} \\ \vdots & \vdots & \ddots\\ y_{n1} & y_{n2} & & y_{nm} \\ \end{bmatrix} \quad \text{X}= \begin{bmatrix} x_{11} & x_{12} & \cdots & x_{1d} \\ x_{21} & x_{22} & \cdots & x_{2d} \\ \vdots & \vdots & \ddots\\ x_{n1} & x_{n2} & & x_{nd} \\ \end{bmatrix} \quad \text{W}= \begin{bmatrix} w_{11} & w_{12} & \cdots & w_{1m} \\ w_{21} & w_{22} & \cdots & w_{2m} \\ \vdots & \vdots & \ddots\\ w_{d1} & w_{d2} & & w_{dm} \\ \end{bmatrix}

라고 하자. 그러면 구하고자 하는 것은

lX=[lx11lx21lxn1lx12lx22lxn2lx1mlx2mlxnm]\frac{\partial l}{\partial \text{X}}= \begin{bmatrix} \frac{\partial l}{\partial x_{11}} & \frac{\partial l}{\partial x_{21}} & \cdots & \frac{\partial l}{\partial x_{n1}} \\ \frac{\partial l}{\partial x_{12}} & \frac{\partial l}{\partial x_{22}} & \cdots & \frac{\partial l}{\partial x_{n2}} \\ \vdots & \vdots & \ddots\\ \frac{\partial l}{\partial x_{1m}} & \frac{\partial l}{\partial x_{2m}} & & \frac{\partial l}{\partial x_{nm}} \\ \end{bmatrix}

이다. 항목 하나씩 구해보는 거로 생각해보자.

lx11=i=1nj=1mlyijyijx11\frac{\partial l}{\partial x_{11}}=\sum_{i=1}^{n}\sum_{j=1}^m \frac{\partial l}{\partial y_{ij}} \frac{\partial y_{ij}}{\partial x_{11}}

이다. 편미분의 정의에 의해 x11x_{11}에 의해 Y\text{Y}가 변화하고, Y\text{Y}의 각각의 값의 변화량이 ll에 영향을 미치므로 위와 같이 합으로 나타낸다. lyij\frac{\partial l}{\partial y_{ij}}는 이미 주어진 lY\frac{\partial l}{\partial \text{Y}}로 알 수 있고 yijx11\frac{\partial y_{ij}}{\partial x_{11}}Y=XW\text{Y}=\text{XW}식에서 yijy_{ij}로 구할 수 있다.

y11=x11w11+x12w21++x1dwd1y11x11=w11y12=x11w12+x12w22++x1dwd2y12x11=w12y21=x21w11+x22w21++x2dwd1y21x11=0y_{11}=x_{11}w_{11}+x_{12}w_{21}+\cdots+x_{1d}w_{d1}\\ \frac{\partial y_{11}}{\partial x_{11}}=w_{11}\\ y_{12}=x_{11}w_{12}+x_{12}w_{22}+\cdots+x_{1d}w_{d2}\\ \frac{\partial y_{12}}{\partial x_{11}}=w_{12}\\ y_{21}=x_{21}w_{11}+x_{22}w_{21}+\cdots+x_{2d}w_{d1}\\ \frac{\partial y_{21}}{\partial x_{11}}=0\\

이렇게 반복하다보면 lX\frac{\partial l}{\partial \text{X}}를 완전히 구할 수 있다.


행렬로 나타내어 식을 정리하면 아래와 같이 나타낼 수 있고 같은 방식으로 W\text{W}에 대해서도 적용할 수 있다.

lX=lYWTlW=XTlY\frac{\partial l}{\partial X}=\frac{\partial l}{\partial Y}W^T\\ \frac{\partial l}{\partial W}=X^T \frac{\partial l}{\partial Y}\\

Y=XW+b\text{Y}=\text{XW}+\text{b}와 같이 bias가 있는 경우에도 같은 방식으로 구할 수 있다.

lb=lY[111]\frac{\partial l}{\partial \text{b}}=\frac{\partial l}{\partial Y} \begin{bmatrix} 1\\1\\ \vdots \\1 \end{bmatrix}

lY\frac{\partial l}{\partial Y}를 행으로 더한 값이다.

profile
안녕하세요

0개의 댓글