My model sometimes suffers from NaN losses. Here I want to document how I solved those issues.
In some cases the sum of the mask becomes 0. PyTorch will still divide by tensor(0.), and since the masked loss sum is also 0 in that case, the division becomes 0/0, which evaluates to NaN. To avoid this, I simply make sure such cases are never delivered to the model's parameters.
def get_nll_loss(self, logits, target, mask):
    probs = logits.softmax(dim=-1)
    if probs.ndim == 3:
        probs = probs.flatten(0, 1)  # [batch_size*seq_len x vocab_size]
    if target.ndim == 2:
        target = target.flatten(0, 1)  # [batch_size*seq_len]
    # clamp min value to 1e-7 to avoid log(0)
    pt = probs[torch.arange(len(target)), target].clamp(1e-7, 1 - 1e-7)  # [batch_size*seq_len]
    loss = -self.alpha * (1 - pt) ** self.gamma * torch.log(pt)  # [batch_size*seq_len]
    loss = loss * mask.flatten(0, 1)  # [batch_size*seq_len]
    # NOTE: if mask.sum() == 0 this is 0/0, which evaluates to NaN
    loss = loss.sum() / mask.sum()  # calculating mean loss considering mask
    return loss
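One simple way to implement "don't deliver this case to the parameters" is to check the mask before backpropagating. Below is a minimal sketch of that idea (the masked_mean_loss helper and the training-loop lines are illustrative assumptions, not the exact code I use):

import torch

def masked_mean_loss(per_token_loss, mask):
    # Sketch of the "never deliver 0/0 to the parameters" idea:
    # with an empty mask there is nothing to average over, so tell the
    # caller to skip this batch instead of producing a NaN loss.
    denom = mask.sum()
    if denom == 0:
        return None
    return (per_token_loss * mask).sum() / denom

# hypothetical usage in the training loop
# loss = masked_mean_loss(per_token_loss, mask.flatten(0, 1))
# if loss is not None:
#     loss.backward()
#     optimizer.step()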
At some point the model predicts completely wrongly, assigning (numerically) zero probability to the correct answers. log(0) then turns into inf in the loss. To prevent that, I simply clamp the probability pt to a minimum value, as in the function below.
def get_nll_loss(self, logits, target, mask):
    probs = logits.softmax(dim=-1)
    if probs.ndim == 3:
        probs = probs.flatten(0, 1)  # [batch_size*seq_len x vocab_size]
    if target.ndim == 2:
        target = target.flatten(0, 1)  # [batch_size*seq_len]
    # clamp min value to 1e-7 to avoid log(0)
    pt = probs[torch.arange(len(target)), target].clamp(1e-7, 1 - 1e-7)  # [batch_size*seq_len]
    loss = -self.alpha * (1 - pt) ** self.gamma * torch.log(pt)  # [batch_size*seq_len]
    loss = loss * mask.flatten(0, 1)  # [batch_size*seq_len]
    loss = loss.sum() / mask.sum()  # calculating mean loss considering mask
    return loss
Without the clamp, the last 30 entries of pt and of the per-token loss looked like this:

# pt[-30:]
tensor([1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 6.8668e-20,
1.0000e+00, 7.2376e-21, 9.1739e-16, 1.8064e-16, 7.5304e-17, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 8.4078e-45, 1.4013e-45,
1.4013e-45, 0.0000e+00, 1.8217e-44, 1.4013e-45, 2.8026e-45, 2.8026e-45,
1.4013e-45, 8.4078e-45, 1.4013e-44, 1.4013e-45, 4.8723e-42, 7.4269e-44],
device='cuda:0', grad_fn=<SliceBackward0>)
# loss[-30:]
tensor([ -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, 44.1250, -0.0000,
46.3750, 34.6250, 36.2500, 37.1250, inf, inf, inf,
inf, inf, 101.4872, 103.2789, 103.2789, inf, 100.7140,
103.2789, 102.5858, 102.5858, 103.2789, 101.4872, 100.9763, 103.2789,
95.1250, 99.3086], device='cuda:0', grad_fn=<SliceBackward0>)
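With the clamp in place, -log(pt) is at most about -log(1e-7) ≈ 16.12, so the loss stays finite and the inf values above can no longer appear. A quick sanity check (alpha = 1 and gamma = 2 are assumed values for illustration, not necessarily the real self.alpha / self.gamma):

import torch

eps = 1e-7
pt = torch.tensor([0.0, 1.4013e-45, 0.5, 1.0]).clamp(eps, 1 - eps)
alpha, gamma = 1.0, 2.0  # assumed hyperparameters, only for this check
loss = -alpha * (1 - pt) ** gamma * torch.log(pt)
print(loss)  # roughly tensor([16.1181, 16.1181, 0.1733, 0.0000]) -- finite, no inf/NaN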