본 포스트는 PyTorch 이용 시 Nan Loss 를 검출하는 방법과 Nan Loss 가 발생하는 이유와 간단한 해결책에 대한 내용을 담고 있습니다.
PyTorch 를 이용하여 모델을 학습 시킬 때, Loss 값이 Nan 이 되는 경우가 존재한다.
Loss 의 값이 Nan(or Infinite) 이 되면, backpropagation 과정에서 문제가 발생하게되며, 이는 곧 모델의 학습이 원활히 이루어지지 않음을 의미한다.
학습 과정에서 Loss 값이 Nan 이 되는 경우 학습을 자동으로 중단시킬 수 있다면, 불필요한 시간의 낭비를 줄일 수 있을 것이다.
PyTorch 에서는 Loss 값이 Nan 인지 아닌지를 판단하는 함수가 존재한다.
torch.isfinite(loss)# loss 가 finite 한 값이면 True, 아니면 False 반환
함수명에서 눈치 챌 수 있듯이 loss 가 finite 한 경우, 즉 loss 값이 유한한 값이라면 (numerically 표현 가능한 경우) True 를 반환하고 아닌 경우 False 를 반환한다!
아래는 위 함수를 이용하여 학습 과정에서 Loss 값을 검사하고 그 값이 Nan 또는 Infinite 인 경우 학습을 중단시키는 방법을 기술한 코드이다.
your_loss = ...
if not torch.isfinite(loss):
print('WARNING: non-finite loss, ending training ')
exit(1)
your_loss.backward()
optimizer.step()
Nan Loss 가 발생하는 이유는 굉장히 다양하다. 본 포스트에서는 Nan Loss 가 발생하는 경우에 대한 대표적인 원인과 해결방법을 소개한다.
높은 Learning Rate 값을 이용하는 경우 초기에 우리의 Loss 값에 따라 Backpropagation 시 weight 값이 크게 변하며 발산하게 되고 그에 따라 Loss 값 또한 커지며 발산하게 된다.
이런 경우에 낮은 Learning Rate 값으로 재학습 시키는 방법이 가장 간단한 해결책이다. 하지만, 높은 Learning Rate 값으로 학습을 시키고 싶다면 Learning Rate Warmup, Gradient Clipping, Batch Normalization 등의 기법을 적용하여 해당 문제를 완화시킬 수 있으니 시도해보길 바란다.
아래의 링크는 Learning Rate Warmup 을 구현한 코드이다.
Loss 의 계산과정 혹은 Forward 과정에서 Numerical Exception 이 발생하는 경우 Loss 값이 Nan 또는 Infinity 값이 될 수 있다. Numerical Exception 이란 컴퓨터가 계산할 수 없는 연산에 대한 예외로 대표적인 예로 Division by Zero, Log(0) 등의 연산이 있다. 이러한 경우는 디버깅을 통해 해결할 수 있다.