[PyTorch] Autograd-03 : Practice02

olxtar·2022년 4월 3일
0

[PyTorch] Autograd

목록 보기
4/5
post-thumbnail

[PyTorch] Autograd-03 : Practice01 이어서




06. retain_graph

Graph가 무엇이냐?
DCG(Dynamic Computation Graph)
뭐 자세히는 모르겠고, 동적으로 즉, 계산 진행과 동시에 아래와 같은 graph를 생성하는 그런거 같다.

[!] backward() 함수는 Graph를 만든 후에 한 번만 호출하는 것을 가정하고 있다.
\therefore backward()를 한 번 이상 호출하면 오류가 나온다.


출처 : https://teamdable.github.io/techblog/PyTorch-Autograd




x = torch.tensor(5.0, requires_grad=True)
y = x**3
z = torch.log(y)

z.backward(retain_graph=True)      # graph를 유지하라!

print('x after backward', get_tensor_info(x))
print('y after backward', get_tensor_info(y))
print('z after backward', get_tensor_info(z))

z.backward()

print('x after 2backward', get_tensor_info(x))
print('y after 2backward', get_tensor_info(y))
print('z after 2backward', get_tensor_info(z))

>>>
x after backward requires_grad(True) is_leaf(True) retains_grad(False) grad_fn(None) grad(0.6000000238418579) tensor(tensor(5., requires_grad=True))
y after backward requires_grad(True) is_leaf(False) retains_grad(False) grad_fn(<PowBackward0 object at 0x7f86f95fc0a0>) grad(None) tensor(tensor(125., grad_fn=<PowBackward0>))
z after backward requires_grad(True) is_leaf(False) retains_grad(False) grad_fn(<LogBackward0 object at 0x7f86fcd5a7f0>) grad(None) tensor(tensor(4.8283, grad_fn=<LogBackward0>))
x after 2backward requires_grad(True) is_leaf(True) retains_grad(False) grad_fn(None) grad(1.2000000476837158) tensor(tensor(5., requires_grad=True))
y after 2backward requires_grad(True) is_leaf(False) retains_grad(False) grad_fn(<PowBackward0 object at 0x7f86f95e4f10>) grad(None) tensor(tensor(125., grad_fn=<PowBackward0>))
z after 2backward requires_grad(True) is_leaf(False) retains_grad(False) grad_fn(<LogBackward0 object at 0x7f86f95e4f40>) grad(None) tensor(tensor(4.8283, grad_fn=<LogBackward0>))
  • backward()
    "(거꾸로 올라가면서 얘의) Gradient 계산해"
  • backward(retain_graph=True)
    "Gradient 계산하긴 하는데! 만든 graph는 날리지말고 유지시켜"

# 위 코드의 실행결과에서 x.grad의 값이 어떻게 되는지만 확인해보자

x after backward grad(0.6000000238418579)

x after 2backward grad(1.2000000476837158)

z.backward(retain_graph=True)와 같이 graph를 유지시키면
즉, backward() 호출에 필요한...
즉, Gradient를 계산하기 위해 필요한...
자원들을 해제하지 않음.
따라서 backward() 호출을 한번 더 할 수 있으나
새로 값을 덮어씌우지 않고 더해버림!

[?] 어떤 값에 대한 최종 output의 Gradient를 중첩시킬 필요가 있을까?

  loss  weight\frac{\partial \;loss}{\partial \;weight}


따라서 backward()를 한 번 이상 호출할 때는
Gradient가 저장되는 곳, 즉 x.grad초기화 시켜준다.

\rightarrow x.grad.zero_()





07. Stem 추가(파생변수 추가)

'파생변수 추가' 라는 것이 맞는 표현인지는 모르겠음

지금까지 작성했던 코드에서는... 아래와 같은 파생 연산이 이루어졌었다.

x=5x{\color{gray}=5}
y=f(x)=x3y=f(x)=x^3
z=g(y)=ln  yz=g(y)=ln \;y
z=g(f(x))=ln  x3{\color{gray}z=g(f(x))=ln\;x^3}

xyzx\rightarrow y\rightarrow z
Leaf - Stem - Root

여기서 Stem을 하나 더 추가해보자.


x = torch.tensor(5.0, requires_grad=True)
y = x ** 3
w = x ** 2
z = torch.log(y) + torch.sqrt(w)

print('x', get_tensor_info(x))
print('y', get_tensor_info(y))
print('w', get_tensor_info(w))
print('z', get_tensor_info(z))

>>>
x requires_grad(True) is_leaf(True) retains_grad(False) grad_fn(None) grad(None) tensor(tensor(5., requires_grad=True))
y requires_grad(True) is_leaf(False) retains_grad(False) grad_fn(<PowBackward0 object at 0x7f86fcd41fd0>) grad(None) tensor(tensor(125., grad_fn=<PowBackward0>))
w requires_grad(True) is_leaf(False) retains_grad(False) grad_fn(<PowBackward0 object at 0x7f86f6621b50>) grad(None) tensor(tensor(25., grad_fn=<PowBackward0>))
z requires_grad(True) is_leaf(False) retains_grad(False) grad_fn(<LogBackward0 object at 0x7f86f6621520>) grad(None) tensor(tensor(4.8283, grad_fn=<LogBackward0>))
  1. requires_grad는 말단의 leaf라고 볼 수 있는 x에 켜주다보니 x로부터 파생된 변수 y,w,z 모두 requires_grad=True이다.
  2. is_leafxTrue이고 나머지는 False이다.
  3. grad_fn 에서는 yw 모두 x로부터 파생연산(제곱연산)된 것 이다보니 PowBackward Function이 할당되었다.

이제 z.backward()를 호출시킨 다음 어떻게 gradient가 계산되는지 보자!
# gradient 연산이 어떻게 이루어지는지 확인하기 위하여 y,w,z의 gradient를 보존하자.


y.retain_grad()    
w.retain_grad()
z.retain_grad()

z.backward()

print('x after backward', get_tensor_info(x))
print('y after backward', get_tensor_info(y))
print('w after backward', get_tensor_info(w))
print('z after backward', get_tensor_info(z))

>>>
x after backward requires_grad(True) is_leaf(True) retains_grad(False) grad_fn(None) grad(1.600000023841858) tensor(tensor(5., requires_grad=True))
y after backward requires_grad(True) is_leaf(False) retains_grad(True) grad_fn(<PowBackward0 object at 0x7f86fcd41fa0>) grad(0.00800000037997961) tensor(tensor(125., grad_fn=<PowBackward0>))
w after backward requires_grad(True) is_leaf(False) retains_grad(True) grad_fn(<PowBackward0 object at 0x7f86fcd5a9d0>) grad(0.10000000149011612) tensor(tensor(25., grad_fn=<PowBackward0>))
z after backward requires_grad(True) is_leaf(False) retains_grad(True) grad_fn(<AddBackward0 object at 0x7f86fcd5a2b0>) grad(1.0) tensor(tensor(9.8283, grad_fn=<AddBackward0>))

먼저 x,y,w,zx,y,w,z의 관계를 수식으로 적어보자.
x=5x{\color{gray}=5}
y=f(x)=x3y=f(x)=x^3
w=g(x)=x2w=g(x)=x^2
z=h(y,w)=ln  y  +  wz=h(y,w)=ln\;y\;+\;\sqrt{w}

거슬러 올라가면서 gradient는 어떻게 구할까?

  1. zz  \frac{\partial z}{\partial z}\;를 구한다.
  2. zy,zw  \frac{\partial z}{\partial y},\frac{\partial z}{\partial w}\;를 구한다.
  3. yx  \frac{\partial y}{\partial x}\;wx\frac{\partial w}{\partial x}를 구한다.
  4. 아래와 같이 Chain Rule을 사용한다.
zzzyyx  +  zzzwwx\frac{\partial z}{\partial z}\cdot \frac{\partial z}{\partial y} \cdot \frac{\partial y}{\partial x}\;+\;\frac{\partial z}{\partial z}\cdot \frac{\partial z}{\partial w} \cdot \frac{\partial w}{\partial x}

zzzyyx  +  zzzwwx{\color{gray}\frac{\partial z}{\cancel{\partial z}}\cdot \frac{\cancel{\partial z}}{\cancel{\partial y}} \cdot \frac{\cancel{\partial y}}{\partial x}\;+\;\frac{\partial z}{\cancel{\partial z}}\cdot \frac{\cancel{\partial z}}{\cancel{\partial w}} \cdot \frac{\cancel{\partial w}}{\partial x}}
  1. zz=1\frac{\partial z}{\partial z}=1

  2. zy=y(ln  y  +  w)=1y=1x3=1125=0.008\frac{\partial z}{\partial y}=\frac{\partial}{\partial y}(ln\;y\;+\;\sqrt{w})=\frac{1}{y}=\frac{1}{x^3}=\frac{1}{125}=0.008
    zw=w(ln  y  +  w)=12w12=12(x2)12=12x=0.1\frac{\partial z}{\partial w}=\frac{\partial}{\partial w}(ln\;y\;+\;\sqrt{w})=\frac{1}{2}w^{-\frac{1}{2}}=\frac{1}{2}(x^2)^{-\frac{1}{2}}=\frac{1}{2x}=0.1

  3. yx=3x2=75\frac{\partial y}{\partial x}=3x^2=75
    wx=2x=10\frac{\partial w}{\partial x}=2x=10

  4. 10.00875  +  10.110=1.61\cdot0.008\cdot75\;+\;1\cdot0.1\cdot10=1.6

이와 같이 xy,wzx\rightarrow y,w\rightarrow z일때
xx에 대한 zz의 Gradient는
xxyy를 통해 zz에 준 영향과
xxww를 통해 zz에 준 영향을 더해서 계산한다.

backward()에서 grad에 Gradient를 저장할 때 기존의 grad의 Gradient를 더하기 때문에 06.retain_graph 참고 이런 계산이 자연스럽게 이루어짐.
Convolutional Neural Network의 Convolution Filter처럼 한 Weight가 여러 계산에 Share되면서 계산되는 경우에, 이런 식으로 Gradient가 합산되면서 grad에 저장됩니다. 그렇다고 합니다...

여기서 중요한게 우리는 .backward()를 어떤 Tensor에 호출시켜줌으로서 해당 Tensor의 Gradient를 구하는 것 이다.
따라서 우리가 yw에 backward()를 호출시키지 않는 한 yw의 Gradient는 볼 수 없다. 계산이 되더라도...

  • 컴퓨터가 보여주는 것

zz=1.0\frac{\partial z}{\partial z}=1.0

zy=0.008\frac{\partial z}{\partial y}=0.008

zw=0.1\frac{\partial z}{\partial w}=0.1

  • 컴퓨터가 보여주지 않는 것

yx=75\frac{\partial y}{\partial x}=75

wx=10\frac{\partial w}{\partial x}=10





08. Leaf 추가

q = torch.tensor(3.0, requires_grad=True)
x = torch.tensor(5.0, requires_grad=True)
y = x ** q
z = torch.log(y)

print('q', get_tensor_info(q))
print('x', get_tensor_info(x))
print('y', get_tensor_info(y))
print('z', get_tensor_info(z))

>>>
q requires_grad(True) is_leaf(True) retains_grad(False) grad_fn(None) grad(None) tensor(tensor(3., requires_grad=True))
x requires_grad(True) is_leaf(True) retains_grad(False) grad_fn(None) grad(None) tensor(tensor(5., requires_grad=True))
y requires_grad(True) is_leaf(False) retains_grad(False) grad_fn(<PowBackward1 object at 0x000001C2AFE2A160>) grad(None) tensor(tensor(125., grad_fn=<PowBackward1>))
z requires_grad(True) is_leaf(False) retains_grad(False) grad_fn(<LogBackward0 object at 0x000001C2AFE2AEB0>) grad(None) tensor(tensor(4.8283, grad_fn=<LogBackward0>))

qx로 부터 y가 파생되고, 그 y로부터 z가 파생되었다.
따라서 qx의 is_leaf값이 True임을 확인할 수 있고...
y의 grad_fn값이 PowBackward1임을 볼 수 있다. [?] PowBackward '1'?

q,x,y,z의 관계를 수식으로 살펴보면 다음과 같다.

q=3  ,  x=5q=3\;,\;x=5

y=f(x,q)=xqy=f(x,q)=x^q

z=g(y)=ln  yz=g(y)=ln\;y

backward()를 호출하여 즉, 거슬러 올라가면서 Gradient는 어떻게 구할까?


q = torch.tensor(3.0, requires_grad=True)
x = torch.tensor(5.0, requires_grad=True)
y = x ** q
z = torch.log(y)

z.backward()

print('q_after_backward', get_tensor_info(q))
print('x_after_backward', get_tensor_info(x))
print('y_after_backward', get_tensor_info(y))
print('z_after_backward', get_tensor_info(z))

>>>
q_after_backward requires_grad(True) is_leaf(True) retains_grad(False) grad_fn(None) grad(1.6094380617141724) tensor(tensor(3., requires_grad=True))
x_after_backward requires_grad(True) is_leaf(True) retains_grad(False) grad_fn(None) grad(0.6000000238418579) tensor(tensor(5., requires_grad=True))
y_after_backward requires_grad(True) is_leaf(False) retains_grad(False) grad_fn(<PowBackward1 object at 0x000001C2AFE0D610>) grad(None) tensor(tensor(125., grad_fn=<PowBackward1>))
z_after_backward requires_grad(True) is_leaf(False) retains_grad(False) grad_fn(<LogBackward0 object at 0x000001C2AFE0DA90>) grad(None) tensor(tensor(4.8283, grad_fn=<LogBackward0>))

07. Stem 추가(파생변수 추가) 와는 다르게 여기서는 Leaf가 추가되었다.
따라서 두 개 의 Gradient 즉,
1. q에 대한 z의 Gradient =zq=\frac{\partial z}{\partial q}
2. x에 대한 z의 Gradient =zx=\frac{\partial z}{\partial x}
를 구하게 된다.

  1. zq=zzzyyq\frac{\partial z}{\partial q}=\frac{\partial z}{\partial z}\cdot\frac{\partial z}{\partial y}\cdot\frac{\partial y}{\partial q}
  2. zx=zzzyyx\frac{\partial z}{\partial x}=\frac{\partial z}{\partial z}\cdot\frac{\partial z}{\partial y}\cdot\frac{\partial y}{\partial x}

zz=1\frac{\partial z}{\partial z}=1
zy=y  ln  y=1y=1xq=153=0.008\frac{\partial z}{\partial y}=\frac{\partial}{\partial y}\;ln\;y=\frac{1}{y}=\frac{1}{x^q}=\frac{1}{5^3}=0.008

yq=qxq=xqln  x=53ln  51251.61201.25\frac{\partial y}{\partial q}=\frac{\partial}{\partial q}x^q=x^q\cdot ln\;x=5^3\cdot ln\;5 \simeq 125\cdot 1.61\simeq201.25

yx=qxq=qxq1=3531=75\frac{\partial y}{\partial x}=\frac{\partial}{\partial q}x^q=qx^{q-1}=3\cdot5^{3-1}=75

\therefore

zq=zzzyyq=1×0.008×201.25=1.61\frac{\partial z}{\partial q}=\frac{\partial z}{\partial z}\cdot\frac{\partial z}{\partial y}\cdot\frac{\partial y}{\partial q}=1\times0.008\times201.25=1.61
zx=zzzyyx=1×0.008×75=0.6\frac{\partial z}{\partial x}=\frac{\partial z}{\partial z}\cdot\frac{\partial z}{\partial y}\cdot\frac{\partial y}{\partial x}=1\times0.008\times75=0.6


[+] 지수함수의 미분법
y=ax  y=axln  ay=a^x \rightarrow\; y\prime=a^x\cdot ln\;a
y=af(x)  y=af(x)ln  af(x)y=a^{f(x)} \rightarrow\; y\prime =a^{f(x)}\cdot ln\;a\cdot f\prime(x)


q.gradx.grad를 출력해볼 필요도 없이
qxgrad에 각각 1.609...와 0.6이 할당되어 있는 것을 볼 수 있다.

# 노파심
print("z gradient w.r.t q:", q.grad)
print("z gradient w.r.t x:", x.grad)

>>>
z gradient w.r.t q: tensor(1.6094)
z gradient w.r.t x: tensor(0.6000)

아래의 그림을 꼭 살펴보자 (출처 : https://teamdable.github.io/techblog/PyTorch-Autograd)


q와 x의 정방향 흐름(파란색)부터 보고, 역방향 흐름(빨간색)을 보셈





09. Operation

참고 : https://hongl.tistory.com/206

PyTorch에서는 어떠한 기능을 하는 함수에 대해
forward(), backward()torch.autograd 모듈로 새롭게 정의할 수 있습니다.

torch.autograd.Function 클래스를 상속하여
@staticmethod를 이용하여
입력에 대한 함수의 동작을 forward() 함수에
함수 출력에 대한 기울기를 받아 / 입력에 대한 기울기를 계산하는 backward() 함수를
새롭게 정의합니다.

무슨 소릴까...천천히 하나하나씩 정리해보자

니가 작성한 것 : [Python] Instance, Class, and Static Methods



08. Leaf 추가 에서 qx라는 leaf 2개를 통해 y가 파생되었고   y=f(x,q)\rightarrow\;y=f(x,q)
ygrad_fnPowBackward1이라는 함수를 저장하고 있었다.

우리는 이 PowBackward1이 어떻게 작동하는지 관심이 있어서
직접 Pow 함수를 만들어서 살펴보려 한다.

class MyPow(torch.autograd.Function):
  @staticmethod
  def forward(ctx, input_1, input_2):
    ctx.save_for_backward(input_1, input_2)
    result = input_1 ** input_2
    return result

  @staticmethod
  def backward(ctx, grad_output):
    input_1, input_2 = ctx.saved_tensors
    grad_input_1 = grad_output * input_2 * input_1 ** (input_2 - 1)
    grad_input_2 = grad_output * input_1 ** input_2 * torch.log(input_1)
    print('input_1', input_1)
    print('input_2', input_2)
    print('grad_output', grad_output)
    print('grad_input_1', grad_input_1)
    print('grad_input_2', grad_input_2)
    return grad_input_1, grad_input_2

PyTorch에서 '연산자' 즉, Operation을 정의할 때 torch.autograd.Function를 상속하여 forward()backward()를 구현합니다.

forward()에는 ctx와 연산자에 전달되는 argument이 차례대로 전달되고 이것을 이용해서 연산자가 계산해야 될 계산을 한 후에 계산결과를 return합니다. 여기서 추가로 처리해줘야 할 것이 있는데, backward()에서 Gradient를 계산하기 위해서는 forward()의 연산 당시의 상태를 알고 있어야 하기 때문에, backward()에서 필요한 상태정보를 forward()에서 ctx.save_for_backward()를 호출하여 저장해줘야 합니다.


예를 들어서... 08. Leaf 추가에서의 코드처럼 (아래) 있을때

q=3  ,  x=5q=3\;,\;x=5

y=f(x,q)=xqy=f(x,q)=x^q

z=g(y)=ln  yz=g(y)=ln\;y

backward()yx=qxq1\frac{\partial y}{\partial x}=qx^{q-1}이라는 사실은 혼자서도 구할 수 있지만, Gradient를 구체적인 숫자로 계산하기 위해서는 forward()의 연산 당시의 구체적인 qqxx의 값을 알아야
yx=qxq1=3×531=75\frac{\partial y}{\partial x}=qx^{q-1}=3\times5^{3-1}=75 요렇게 구할 수 있습니다. 따라서 MyPow Class의 forward()backward() (static) method를 살펴보면...





class MyPow(torch.autograd.Function):
  @staticmethod
  def forward(ctx, input_1, input_2):
    ctx.save_for_backward(input_1, input_2)
    result = input_1 ** input_2
    return result
    
    
  @staticmethod
  def backward(ctx, grad_output):
    input_1, input_2 = ctx.saved_tensors
    grad_input_1 = grad_output * input_2 * input_1 ** (input_2 - 1)
    grad_input_2 = grad_output * input_1 ** input_2 * torch.log(input_1)

[+] context, ctx : n. 맥락, 배경상황


forward method

  • ctx : forward에서는 input1, input2를 backward용으로 저장해두는 곳 (메모장 느낌?)

  • input_1, input_2 : 이 MyPow 클래스는 지금 아래의 연산을 위해 존재하는 것이므로 각각 x=5x=5, q=3q=3 이다.
    forward method : y=f(x,q)=xqy=f(x,q)=x^q
    backward method : zyyx    zyyq\frac{\partial z}{\partial y}\cdot \frac{\partial y}{\partial {\color{orange}x}}\;\;\frac{\partial z}{\partial y}\cdot \frac{\partial y}{\partial {\color{pink}q}}

  • result : xq=53=75x^q=5^3=75

backward method

  • ctx : 앞서 forward에서 input1, input2 즉 xxqq를 저장해 둔 곳
  • grad_output : 아래 그림에서와 같이 MyPow Class의 backward method
    즉, \simeq 아래 그림의 PowBackwardzy\frac{\partial z}{\partial y}를 input으로 받아 연산을 수행하고 x{\color{orange}x}에게는 zyyx\frac{\partial z}{\partial y}\cdot\frac{\partial y}{\partial {\color{orange}x}}를 주고, q{\color{pink}q}에게는 zyyq\frac{\partial z}{\partial y}\cdot \frac{\partial y}{\partial {\color{pink}q}}를 전달해준다. 따라서 input으로 받은 zy\frac{\partial z}{\partial y} 값임.
    업로드중..

  • grad_input_1 : operation code를 보면 zyqxq1\frac{\partial z}{\partial y}\cdot qx^{q-1}이므로... zyyx\frac{\partial z}{\partial y}\cdot\frac{\partial y}{\partial {\color{orange}x}}이겠구나!
  • grad_input_2 : operation cdoe를 보면 zyxqln  x\frac{\partial z}{\partial y}\cdot x^q\cdot ln\;x이므로... zyyq\frac{\partial z}{\partial y}\cdot\frac{\partial y}{\partial {\color{pink}q}}이겠구나!



[!]여기서 Keypointforward() 연산 진행 시 ctx.save_for_backward(input_1, input_2)을 통해서 xx값과 qq값을 저장하고 backward() 연산에 ctx.saved_tensors를 통해서 넘겨준다는 것이다.


[+] Code 내 변수들에게 저장된 값은 아래와 같다.

input_1 tensor(5., requires_grad=True)
input_2 tensor(3., requires_grad=True)
grad_output tensor(0.0080)
grad_input_1 tensor(0.6000)
grad_input_2 tensor(1.6094)

input_1 : x=5x=5
input_2 : q=3q=3
grad_output : zzzy=y  ln  y=1y=153=0.008{\color{gray}\frac{\partial z}{\partial z}}\cdot\frac{\partial z}{\partial y}=\frac{\partial}{\partial y}\;ln\;y=\frac{1}{y}=\frac{1}{5^3}=0.008
grad_input_1 : zyyx=0.008qxq1=0.00875=0.6\frac{\partial z}{\partial y}\cdot \frac{\partial y}{\partial x}=0.008 \cdot qx^{q-1}=0.008 \cdot75=0.6
grad_input_2 : zyyq=0.008xqln  x  =0.00853ln  51.61\frac{\partial z}{\partial y}\cdot\frac{\partial y}{\partial q}=0.008\cdot x^q\cdot ln\;x\;=0.008\cdot 5^3\cdot ln\;5\simeq1.61





해당 파트인 09. Operation 맨 앞에서 아래의 글이 잘 이해가 되지 않았었지?

...
PyTorch에서는 어떠한 기능을 하는 함수에 대해
forward(), backward()torch.autograd 모듈로 새롭게 정의할 수 있습니다.

torch.autograd.Function 클래스를 상속하여
@staticmethod를 이용하여
입력에 대한 함수의 동작을 forward() 함수에
함수 출력에 대한 기울기를 받아 / 입력에 대한 기울기를 계산하는 backward() 함수를
새롭게 정의합니다.
...

[?] 아직 torch.autograd.Function을 상속한다는 것은 무슨말인지 모르겠음
[!] forward 메소드에는 입력에 대한 forward 방향 연산을 정의!
backward 메소드에는 해당 클래스함수가 출력한 것에 대한 기울기를 받아서 입력에 대한 기울기를 계산 즉, 위의 예시에서 MyPow 클래스는 무엇을 output 했고 무엇을 input 받았니?

y=f(x,q){\color{blue}y}=f({\color{red}x,q})
  • 출력(output) : yy
  • 입력(input) : xx, qq

\therefore
출력한 것에 대한 기울기 : y?\frac{\partial}{\partial y}?
입력한 것에 대한 기울기 : x?,  q?\frac{\partial}{\partial x}?,\;\frac{\partial}{\partial q}?



10. .no_grad()

09. Operation에서 보았듯이
연산함수 즉 Operation function (MyPow 클래스...)에는 forward와 backward 메소드가 있고 backward 메소드 호출을 대비하여 ctx.save_for_backward(), grad_fn 등을 준비하므로 상당한 양의 메모리를 사용하게 된다.
[+] 선행조건 : 연산되는 Tensor의 requires_grad=True

근데 Training 과정에서만 backward가 필요하지 Inference, 즉 추론(= Test, Validation)과정에서는 backward가 필요하지 않다.

따라서

with torch.no_grad()
	q = torch.tensor(3.0, requires_grad=True)
    x = torch.tensor(5.0, requires_grad=True)
    ...

이와 같이 torch.no_grad() context를 사용하여 해당 context 내에서 생성된 Tensor들의 requires_grad가 False로 설정되도록 할 수 있다. 이는 메모리 사용량을 크게 줄일 수 있다.

profile
예술과 기술

0개의 댓글