[PyTorch] #2 NaN

Clay Ryu's sound lab · March 7, 2024

My model sometimes suffers from NaN problems. Here I want to lay out how I solved these issues.

dividing by tensor(0)

In some cases the sum of the mask becomes 0. PyTorch happily divides a tensor by tensor(0.) without raising an error, producing inf or NaN values. So, to avoid that, I simply make sure such cases are never delivered to the model's parameters.

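As a quick illustration (my own minimal sketch, not from the original post): floating-point division by a zero tensor silently produces inf, and 0/0, which is exactly what an all-zero mask produces when the masked loss sum is divided by the mask sum, gives nan.

import torch

torch.tensor(1.0) / torch.tensor(0.0)  # tensor(inf)
torch.tensor(0.0) / torch.tensor(0.0)  # tensor(nan) -- the all-zero-mask case

The loss function below adds a guard so this case never produces a NaN loss: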
  def get_nll_loss(self, logits, target, mask):
    probs = logits.softmax(dim=-1)
    if probs.ndim == 3:
      probs = probs.flatten(0, 1) # [batch_size*seq_len x vocab_size]
    if target.ndim == 2:
      target = target.flatten(0, 1) # [batch_size*seq_len]
    # clamp min value to 1e-7 to avoid log(0)
    pt = probs[torch.arange(len(target)), target].clamp(1e-7, 1-1e-7) # [batch_size*seq_len]
    loss = -self.alpha * (1-pt)**self.gamma * torch.log(pt) # [batch_size*seq_len]
    loss = loss * mask.flatten(0, 1) # [batch_size*seq_len]
    # guard: with an all-zero mask, loss.sum() / mask.sum() would be 0 / 0 = NaN,
    # so return a zero loss that still carries a grad_fn (one way to realize the fix)
    if mask.sum() == 0:
      return logits.sum() * 0.0
    loss = loss.sum() / mask.sum() # mean loss over the unmasked positions
    return loss
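A branch-free alternative (my own sketch, not from the original post) is to clamp the denominator so the division can never be 0 / 0:

loss = loss.sum() / mask.sum().clamp(min=1)  # denominator is at least 1, never 0

Since the numerator is already zeroed out by the mask, an all-zero mask then yields a clean 0 loss instead of NaN.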

log(0)

At some point the model predicts completely wrong, assigning a probability of 0 to the correct answers. To prevent log(0), I simply clamp the probabilities to a minimum value.
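A quick check (my own sketch) of why the clamp works: torch.log of an exact zero is -inf, while clamping first keeps every entry finite.

import torch

p = torch.tensor([0.0, 1e-45, 0.5])
torch.log(p)                        # tensor([    -inf, -103.2789,   -0.6931])
torch.log(p.clamp(1e-7, 1 - 1e-7))  # tensor([-16.1181,  -16.1181,  -0.6931])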

The relevant lines are the clamp inside get_nll_loss above:

    # clamp min value to 1e-7 to avoid log(0)
    pt = probs[torch.arange(len(target)), target].clamp(1e-7, 1-1e-7) # [batch_size*seq_len]
    loss = -self.alpha * (1-pt)**self.gamma * torch.log(pt) # [batch_size*seq_len]
For comparison, here is pt without the clamp: the model assigns exact zeros and denormal values to some of the correct answers.

# pt[-30:]
tensor([1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 6.8668e-20,
        1.0000e+00, 7.2376e-21, 9.1739e-16, 1.8064e-16, 7.5304e-17, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 8.4078e-45, 1.4013e-45,
        1.4013e-45, 0.0000e+00, 1.8217e-44, 1.4013e-45, 2.8026e-45, 2.8026e-45,
        1.4013e-45, 8.4078e-45, 1.4013e-44, 1.4013e-45, 4.8723e-42, 7.4269e-44],
       device='cuda:0', grad_fn=<SliceBackward0>)
Those zeros turn into inf in the loss, while denormals like 1.4013e-45 give losses near -log(1.4013e-45) ≈ 103.28.

# loss[-30:]
tensor([ -0.0000,  -0.0000,  -0.0000,  -0.0000,  -0.0000,  44.1250,  -0.0000,
         46.3750,  34.6250,  36.2500,  37.1250,      inf,      inf,      inf,
             inf,      inf, 101.4872, 103.2789, 103.2789,      inf, 100.7140,
        103.2789, 102.5858, 102.5858, 103.2789, 101.4872, 100.9763, 103.2789,
         95.1250,  99.3086], device='cuda:0', grad_fn=<SliceBackward0>)
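With the clamp in place, pt can never drop below 1e-7, so each loss entry is bounded by about -log(1e-7) ≈ 16.12 (scaled by alpha), and no inf or NaN reaches the backward pass.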