pytorch의 weight normalization function인 weight_norm은 weight가 0으로 채워졌을 경우 NaN 반환한다.
이를 해결하기 위하여
def compute_weight(self, module):
g = getattr(module, self.name + '_g')
v = getattr(module, self.name + '_v')
w = v*(g/(torch.norm_except_dim(v, 2, dim)+1e-8)).expand_as(v)
return w
위의 방법으로 해결되지 않을 때, 대안으로 아래와 같이 작은 값을 더해 weight를 뽑아내는 일종의 트릭(?)을 사용할 수도 있다.
def layer_recalibration(self, layer):
if(torch.isnan(layer.weight_v).sum() > 0):
print ('recalibrate layer.weight_v')
layer.weight_v = torch.nn.Parameter(torch.where(torch.isnan(layer.weight_v), torch.zeros_like(layer.weight_v), layer.weight_v))
layer.weight_v = torch.nn.Parameter(layer.weight_v + 1e-7)
if(torch.isnan(layer.weight).sum() > 0):
print ('recalibrate layer.weight')
layer.weight = torch.where(torch.isnan(layer.weight), torch.zeros_like(layer.weight), layer.weight)
layer.weight += 1e-7
dense layer 전에 적용하면 된다.
def forward(self, x):
x = self.batch_norm(x)
x = self.dropout(x)
self.recalibrate_layer(self.dense1)
x = self.dense(x)
x = F.relu(x)
return x
loss 값이 정상적으로 출력되는 것을 확인 가능