Matrix Derivative
기초
A 벡터는 상수만 들어있을 때 다음과 같이 구할 수 있다
A = ( 1 2 3 4 ) A = \begin{pmatrix} 1 & 2 \\ 3 & 4 \end{pmatrix} A = ( 1 3 2 4 )
A x ⃗ = ( 1 2 3 4 ) ( x 1 x 2 ) = ( x 1 + 2 x 2 3 x 1 + 4 x 2 ) = ( f 1 ( x 1 , x 2 ) f 2 ( x 1 , x 2 ) ) A\vec{x} = \begin{pmatrix} 1 & 2 \\ 3 & 4 \end{pmatrix}\begin{pmatrix} x_1 \\ x_2 \end{pmatrix} = \begin{pmatrix} x_1+2x_2 \\ 3x_1+4x_2 \end{pmatrix} = \begin{pmatrix} f_1(x_1, x_2) \\ f_2(x_1, x_2) \end{pmatrix} A x = ( 1 3 2 4 ) ( x 1 x 2 ) = ( x 1 + 2 x 2 3 x 1 + 4 x 2 ) = ( f 1 ( x 1 , x 2 ) f 2 ( x 1 , x 2 ) )
d d x A x ⃗ = [ d f 1 d x 1 d f 1 d x 2 d f 2 d x 1 d f 2 d x 2 ] = [ 1 2 3 4 ] \frac{d}{dx}A\vec{x} = \begin{bmatrix} \frac{df_1}{dx_1} & \frac{df_1}{dx_2}\\ \frac{df_2}{dx_1} & \frac{df_2}{dx_2} \end{bmatrix} = \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} d x d A x = [ d x 1 d f 1 d x 1 d f 2 d x 2 d f 1 d x 2 d f 2 ] = [ 1 3 2 4 ]
심화
A = [ 5 3 3 4 ] = [ a 11 a a a 22 ] A=\begin{bmatrix} 5 & 3 \\ 3 & 4 \end{bmatrix}=\begin{bmatrix} a_{11} & a \\ a & a_{22} \end{bmatrix} A = [ 5 3 3 4 ] = [ a 1 1 a a a 2 2 ]
이처럼 A가 symmetric 할 경우 다음이 성립된다
d d x x T A x = 2 A x \frac{d}{dx}x^TAx=2Ax d x d x T A x = 2 A x
증명
x T A x = ( x 1 , x 2 ) ( a 11 a a a 22 ) ( x 1 x 2 ) = ( x 1 , x 2 ) ( a 11 x 1 + a x 2 a 21 x 1 + a 22 x 2 ) x^TAx = (x_1,x_2)\begin{pmatrix} a_{11} & a\\ a& a_{22} \end{pmatrix}\begin{pmatrix} x_1 \\ x_2 \end{pmatrix} = (x_1,x_2)\begin{pmatrix} a_{11}x_1 + ax_2 \\ a_{21}x_1 + a_{22}x_2 \end{pmatrix} x T A x = ( x 1 , x 2 ) ( a 1 1 a a a 2 2 ) ( x 1 x 2 ) = ( x 1 , x 2 ) ( a 1 1 x 1 + a x 2 a 2 1 x 1 + a 2 2 x 2 )
= a 11 x 1 2 + 2 a x 1 x 2 + a 22 x 2 2 =a_{11}x^2_1 + 2ax_1x_2 + a_{22}x^2_2 = a 1 1 x 1 2 + 2 a x 1 x 2 + a 2 2 x 2 2
= f ( x 1 , x 2 ) =f(x_1,x_2) = f ( x 1 , x 2 )
d d x x T A x = ( d f d x 1 d f d x 2 ) = 2 ( a 11 x 1 + a x 2 a x 1 + a 22 x 2 ) = 2 ( a 11 a a a 22 ) ( x 1 x 2 ) = 2 A x \frac{d}{dx}x^TAx = \begin{pmatrix} \frac{df}{dx_1} \\ \frac{df}{dx_2} \end{pmatrix}=2\begin{pmatrix} a_{11}x_1 + ax_2\\ ax_1 + a_{22}x_2 \end{pmatrix}= 2\begin{pmatrix} a_{11} & a \\ a & a_{22} \end{pmatrix}\begin{pmatrix} x_1 \\ x_2 \end{pmatrix} = 2Ax d x d x T A x = ( d x 1 d f d x 2 d f ) = 2 ( a 1 1 x 1 + a x 2 a x 1 + a 2 2 x 2 ) = 2 ( a 1 1 a a a 2 2 ) ( x 1 x 2 ) = 2 A x
Calculus in Autograd
y ⃗ \vec{y} y 는 m개, x ⃗ \vec{x} x 는 n개의 변수를 가지고 있다.
Jacobian Matrix J J J
y ⃗ = f ( x ⃗ ) \vec{y}=f(\vec{x}) y = f ( x ) 에서 x ⃗ \vec{x} x 에 대한 y ⃗ \vec{y} y 의 변화도
J = ( ∂ y ∂ x 1 . . . ∂ y ∂ x n ) = ( ∂ y 1 ∂ x 1 . . . ∂ y 1 ∂ x n . . . ∂ y m ∂ x 1 . . . ∂ y m ∂ x 1 ) J = \begin{pmatrix} \frac{\partial y}{\partial x_1} & ... & \frac{\partial y}{\partial x_n} \end{pmatrix} = \begin{pmatrix} \frac{\partial y_1}{\partial x_1} & ... & \frac{\partial y_1}{\partial x_n} \\ & ... \\ \frac{\partial y_m}{\partial x_1} & ... & \frac{\partial y_m}{\partial x_1} \end{pmatrix} J = ( ∂ x 1 ∂ y . . . ∂ x n ∂ y ) = ⎝ ⎜ ⎛ ∂ x 1 ∂ y 1 ∂ x 1 ∂ y m . . . . . . . . . ∂ x n ∂ y 1 ∂ x 1 ∂ y m ⎠ ⎟ ⎞
torch.autograd
는 Vector-Jacobian 곱셈을 계산하는데
이는, 주어진 어떤 벡터 v ⃗ \vec{v} v 에 대하여 J T ⋅ v ⃗ {J^T}\cdot{\vec{v}} J T ⋅ v 연산을 한다는 의미다
스칼라 함수 l = g ( y ⃗ ) l=g(\vec{y}) l = g ( y ) 의 변화도(gradient)는
chian rule에 따라 l ′ = f ′ ( x ) g ′ ( f ( x ) ) = f ′ ( x ) g ′ ( y ) l^{\prime} = f^{\prime}(x)g^{\prime}(f(x)) = f^{\prime}(x)g^{\prime}(y) l ′ = f ′ ( x ) g ′ ( f ( x ) ) = f ′ ( x ) g ′ ( y ) 로 변환된다.
그리고 y y y 에 대한 l l l 의 변화도는 다음과 같다.
( ∂ l ∂ y 1 . . . ∂ l ∂ y m ) \begin{pmatrix} \frac{\partial l}{\partial y_1} & ... & \frac{\partial l}{\partial y_m} \end{pmatrix} ( ∂ y 1 ∂ l . . . ∂ y m ∂ l )
이에 대한 Transpose가 v ⃗ \vec{v} v 라고 하자
v ⃗ = ( ∂ l ∂ y 1 . . . ∂ l ∂ y m ) T \vec{v} = \begin{pmatrix} \frac{\partial l}{\partial y_1} & ... & \frac{\partial l}{\partial y_m} \end{pmatrix}^T v = ( ∂ y 1 ∂ l . . . ∂ y m ∂ l ) T
x x x 에 대한 l l l 의 변화도의 Transpose를 구한다는 것은 다음과 같이 표현할 수 있다.
( ∂ l ∂ x 1 . . . ∂ l ∂ x m ) = ( ∂ y 1 ∂ x 1 . . . ∂ y m ∂ x 1 . . . ∂ y 1 ∂ x n . . . ∂ y m ∂ x 1 ) ( ∂ l ∂ y 1 . . . ∂ l ∂ y m ) = J T ⋅ v ⃗ \begin{pmatrix} \frac{\partial l}{\partial x_1}\\ ... \\ \frac{\partial l}{\partial x_m} \end{pmatrix} = \begin{pmatrix} \frac{\partial y_1}{\partial x_1} & ... & \frac{\partial y_m}{\partial x_1} \\ & ... \\ \frac{\partial y_1}{\partial x_n} & ... & \frac{\partial y_m}{\partial x_1} \end{pmatrix}\begin{pmatrix} \frac{\partial l}{\partial y_1}\\ ... \\ \frac{\partial l}{\partial y_m} \end{pmatrix} ={J^T}\cdot{\vec{v}} ⎝ ⎜ ⎛ ∂ x 1 ∂ l . . . ∂ x m ∂ l ⎠ ⎟ ⎞ = ⎝ ⎜ ⎛ ∂ x 1 ∂ y 1 ∂ x n ∂ y 1 . . . . . . . . . ∂ x 1 ∂ y m ∂ x 1 ∂ y m ⎠ ⎟ ⎞ ⎝ ⎜ ⎛ ∂ y 1 ∂ l . . . ∂ y m ∂ l ⎠ ⎟ ⎞ = J T ⋅ v
Reference