Source code of Autograd

Human Being·2023년 1월 13일
0

Pytorch

목록 보기
3/3
post-thumbnail

v\vec{v}를 pytorch 코드로 나타내려 한다.

스칼라 함수 l=g(y)l = g(\vec{y})l=2y1+3y2l = 2y_1 + 3y_2로 표현될 때
yy에 대한 ll의 변화도는 (ly1ly2)=(23)\begin{pmatrix} \frac{\partial l}{\partial y_1} & \frac{\partial l}{\partial y_2} \end{pmatrix}=\begin{pmatrix} 2 & 3 \end{pmatrix}이다
v=(23)\vec{v} = \begin{pmatrix} 2 \\ 3 \end{pmatrix}

y1=3x13+2x24y_1=3x_1^3+2x_2^4라고 할 때 y1x1=9x12\frac{\partial y_1}{\partial x_1} = 9x_1^2 이며 y1x2=8x23\frac{\partial y_1}{\partial x_2} = 8x_2^3 이므로

y2=2x12+x23y_2=2x_1^2+x_2^3라고 할 때 y2x1=4x1\frac{\partial y_2}{\partial x_1} = 4x_1 이며 y2x2=3x22\frac{\partial y_2}{\partial x_2} = 3x_2^2 이다.

yy에 대해서 backward()를 호출하면

x1=2, x2=3인 상황을 가정하여 x, y를 코드로 구현하여 각 편미분 공식에 대입한 결과와 자동 미분 결과를 확인하니 값이 일치함을 확인할 수 있다.

import torch

x1 = torch.tensor([2.], requires_grad=True)
x2 = torch.tensor([3.], requires_grad=True)
y1 = 3*x1**3 + 2*x2**4
y1.backward()
print(x1.grad, x2.grad)  ## tensor([36.]) tensor([216.])

###
 
x1 = 2
x2 = 3

y1x1 = 9*x1**2
y1x2 = 8*x2**3

print(y1x1, y1x2)  ## (36, 216)

다음으로 xx에 대한 ll의 변화도를 알아보자

xx에 대한 yy의 변화도인 JT=(9x124x18x233x22)J^T=\begin{pmatrix} 9x^2_1 & 4x_1 \\ 8x_2^3 & 3x_2^2 \end{pmatrix} 이기에

xx에 대한 ll의 변화도의 Transpose는 다음과 같다
(9x124x18x233x22)(23)=(18x12+12x116x23+9x22)\begin{pmatrix} 9x^2_1 & 4x_1 \\ 8x_2^3 & 3x_2^2 \end{pmatrix} \begin{pmatrix} 2 \\ 3 \end{pmatrix} = \begin{pmatrix} 18x_1^2 +12x_1 \\ 16x_2^3 + 9x^2_2 \end{pmatrix}

ll에 대해서 backward()를 실행한 결과 값이 일치함을 확인하였다

x1 = torch.tensor([2.], requires_grad=True)
x2 = torch.tensor([3.], requires_grad=True)

y1 = 3*x1**3 + 2*x2**4
y2 = 2*x1**2 + x2**3

l = 2*y1 + 3*y2

l.backward()


print(x1.grad, x2.grad)  ## tensor([96.]) tensor([513.])
print(y1.retain_grad(), y2.retain_grad())  ## None None
print(l.retain_grad())  ## None

####

lx1 = 18*x1**2 + 12*x1
lx2 = 16*x2**3 + 9*x2**2
print(lx1, lx2)  ## tensor([96.], grad_fn=<AddBackward0>) tensor([513.], grad_fn=<AddBackward0>)

하나 눈 여겨 볼 점으로 y1, y2, l에 gradient 값을 retain_grad()를 이용하여 가져온다.
이는 pytorch에서 backward의 결과값을 leaf-node의 것만 기록해두기 때문이다.


Reference

0개의 댓글