loss가 NaN으로 찍힐 경우(weight_norm이 NaN을 반환할 때)

Jinhyeong Park·2020년 11월 30일
1


pytorch의 weight normalization function인 weight_norm은 weight가 0으로 채워졌을 경우 NaN 반환한다.

이를 해결하기 위하여

def compute_weight(self, module):

        g = getattr(module, self.name + '_g')
        v = getattr(module, self.name + '_v')
        w = v*(g/(torch.norm_except_dim(v, 2, dim)+1e-8)).expand_as(v)

        return w

위의 방법으로 해결되지 않을 때, 대안으로 아래와 같이 작은 값을 더해 weight를 뽑아내는 일종의 트릭(?)을 사용할 수도 있다.

def layer_recalibration(self, layer):

        if(torch.isnan(layer.weight_v).sum() > 0):
            print ('recalibrate layer.weight_v')
            layer.weight_v = torch.nn.Parameter(torch.where(torch.isnan(layer.weight_v), torch.zeros_like(layer.weight_v), layer.weight_v))
            layer.weight_v = torch.nn.Parameter(layer.weight_v + 1e-7)

        if(torch.isnan(layer.weight).sum() > 0):
            print ('recalibrate layer.weight')
            layer.weight = torch.where(torch.isnan(layer.weight), torch.zeros_like(layer.weight), layer.weight)
            layer.weight += 1e-7

dense layer 전에 적용하면 된다.

def forward(self, x):

    x = self.batch_norm(x)
    x = self.dropout(x)

    self.recalibrate_layer(self.dense1)

    x = self.dense(x)
    x = F.relu(x)
    return x

loss 값이 정상적으로 출력되는 것을 확인 가능

reference
https://github.com/pytorch/pytorch/issues/19126

0개의 댓글