Focal Loss는 Object Detection 중 one-stage detector의 성능을 개선하기 위해 고안되었다고 한다. One-stage detector의 문제점은 학습 중 클래스 불균형 문제가 심하다는 것이다. 예를 들어, 배경에 대하여 박스를 친 것이 실제 객체에 대하여 박스를 친 것에 비해 너무 많다는 것이다. 배경의 비율이 크면, loss 계산 시 배경의 영향이 압도적으로 커지는 문제가 발생할 수 있다.
이러한 문제를 개선하기 위해, Focal Loss 개념이 도입되었다. Focal Loss는 Cross-entropy Loss의 클래스 불균형 문제를 다루기 위해 개선된 버전이고, 어렵거나 쉽게 오분류되는 케이스에 대하여 더 큰 가중치를 주는 방법을 사용한다. 반대로 쉬운 케이스의 경우, 낮은 가중치를 반영한다.
Cross-entropy Loss는 잘못 예측한 경우에 대하여 페널티를 부여하는 것에 초점을 둔다. Focal Loss에는 Cross-entropy loss에서 term이 추가되었다. 이 term의 역할은 분류하기 쉬운 example의 경우 loss의 가중치를 줄이기 위함이다.
논문에서 를 사용했다고 한다.
# Focal Loss Implementation
class FocalLoss(nn.Module):
def __init__(self, weight=None, gamma=2, reduction='mean'):
super(FocalLoss, self).__init__()
self.weight = weight
self.gamma = gamma
self.reduction = reduction
def forward(self, inputs, targets):
ce_loss = F.cross_entropy(inputs, targets, weight=self.weight, reduction=self.reduction)
pt = torch.exp(-ce_loss)
focal_loss = ((1-pt)**self.gamma * ce_loss).mean()
return focal_loss