Focal loss 손실함수는 Focal Loss for Dense Object Detection 객체탐지 논문에서 처음으로 등장했다.
논문에서 Focal Loss는 one-stage detector에서 학습 중 클래스 불균형 문제를 해결하고자 고안되었다. 예로 학습 중 박스를 친 대다수의 이미지를 살펴보면 백그라운드 픽셀이 객체의 픽셀보다 많은 비중을 차지하고 있어서 classification에 통과시킨 값을 오분류로 분류해 학습에 방해가 된다고 한다.
Cross Entropy의 클래스 불균형 문제는 Detector에 의해 백그라운드로 분류될 수 있는 easy negative가 대부분이므로 학습에 비효율적이다.
Focal loss는 Cross Entropy의 클래스 불균형 문제를 개선한 개념이며 어렵거나 쉽게 오분류되는 케이스에 대하여 더 큰 가중치를 주는 방법이다.
Entropy loss위의 그림은 Binary Cross Entropy loss로 객체탐지 모델을 학습시켜 나온 결과이다. 위에서 말한 클래스 불균형으로 인해 확실하지 않은 객체에 대해서는 탐지를 전혀 못하고 있는 상황이다.
반면에 Focal loss 경우 클래스 불균형 문제를 극복해서 최대한 많은 객체를 탐지하고 있다.
Focal loss 에서는 마지막에 출력되는 각 클래스의 probability를 이용해 CE Loss에 통과된 최종 확률값이 큰 EASY 케이스의 Loss를 크게 줄이고 최종 확률 값이 낮은 HARD 케이스의 Loss를 낮게 줄이는 역할을 한다. 보통 CE는 확률이 낮은 케이스에 패널티를 주는 역할만 하고 확률이 높은 케이스에 어떠한 보상도 주지만 Focal Loss는 확률이 높은 케이스에는 확률이 낮은 케이스 보다 Loss를 더 크게 낮추는 보상을 주는 차이점이 있다.
class WeightedFocalLoss(nn.Module):
def __init__(self, alpha=.25, gamma=2):
super(WeightedFocalLoss, self).__init__()
self.alpha = torch.tensor([alpha, 1-alpha]).cuda()
self.gamma = gamma
def forward(self, inputs, targets):
BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
targets = targets.type(torch.long)
at = self.alpha.gather(0, targets.data.view(-1))
pt = torch.exp(-BCE_loss)
F_loss = at*(1-pt)**self.gamma * BCE_loss
return F_loss.mean()