[pytorch] no_grad(), model.eval, requires_grad=False 의 차이

seong taek ·2022년 4월 22일
2

딥러닝 모델을 전이학습하거나 이미 학습된 모델로 inference 시에 자주 등장하는데 항상 정확한 의미를 모르고 있어서 간단하게 정리해봤습니다.

model.eval()

딥러닝 모델들은 학습과 추론에서 다르게 작동하는 Layer를 가지고 있는 경우가 있습니다. 대표적으로 BatchNorm, Dropout가 있으며 이런 Layer 들이 추론 과정에서 학습시와는 다르게 작동하게 해주는 것이 .eval() 입니다

#train
...
model.train() #train mode
...

#eval 
model.eval() #-> eval mode
...
with torch.no_grad():
	...

torch.no_grad()

위 코드에서 모델 추론 부분에서는 model.eval() 말고도 torch.no_grad()도 함께 사용했습니다. torch.no_grad()로 감싸진(with) 부분에서는 gradient 계산을 하지 않아 메모리/속도가 개선됩니다. 공식 문서에서는 loss.backward()를 호출하지 않는다면 사용할 것을 권장하고 있습니다.

requires_grad=False

보통 아래와 같이 많이 사용합니다.

for p in model.parametes():
	p.requires_grad=False

requires_grad=False 를 적용하면 모델의 특정 부분에 그라디언트 계산을 멈출 수 있습니다. torch.no_grad() 와 가장 큰 차이는 그라디언트를 저장은 한다는 것입니다. 따라서 모델의 특정 부분은 freeze 하고 나머지는 학습시키는 등의 전략을 사용할 때 사용합니다.

torch.no_grad() VS requires_grad=False 의 설명이 약간 부실한것 같은데 아래 링크로 가시면 더 자세한 설명이 있습니다.

LINK

profile
rucola-pizza

0개의 댓글