Focal Loss

박정재·2023년 3월 18일
0

DL Basics

목록 보기
9/9

Focal Loss는 Object Detection 중 one-stage detector의 성능을 개선하기 위해 고안되었다고 한다. One-stage detector의 문제점은 학습 중 클래스 불균형 문제가 심하다는 것이다. 예를 들어, 배경에 대하여 박스를 친 것이 실제 객체에 대하여 박스를 친 것에 비해 너무 많다는 것이다. 배경의 비율이 크면, loss 계산 시 배경의 영향이 압도적으로 커지는 문제가 발생할 수 있다.

이러한 문제를 개선하기 위해, Focal Loss 개념이 도입되었다. Focal Loss는 Cross-entropy Loss의 클래스 불균형 문제를 다루기 위해 개선된 버전이고, 어렵거나 쉽게 오분류되는 케이스에 대하여 더 큰 가중치를 주는 방법을 사용한다. 반대로 쉬운 케이스의 경우, 낮은 가중치를 반영한다.

Cross-entropy Loss는 잘못 예측한 경우에 대하여 페널티를 부여하는 것에 초점을 둔다. Focal Loss에는 Cross-entropy loss에서 α(1pt)γ\alpha (1-p_t)^\gamma term이 추가되었다. 이 term의 역할은 분류하기 쉬운 example의 경우 loss의 가중치를 줄이기 위함이다.

FL(pt)=α(1pt)γlog(pt)FL(p_t) = -\alpha (1-p_t)^\gamma log(p_t)

논문에서 α=0.25,γ=2\alpha = 0.25, \gamma = 2 를 사용했다고 한다.

# Focal Loss Implementation
class FocalLoss(nn.Module):
  def __init__(self, weight=None, gamma=2, reduction='mean'):
    super(FocalLoss, self).__init__()
    self.weight = weight
    self.gamma = gamma
    self.reduction = reduction

  def forward(self, inputs, targets):
    ce_loss = F.cross_entropy(inputs, targets, weight=self.weight, reduction=self.reduction)
    pt = torch.exp(-ce_loss)
    focal_loss = ((1-pt)**self.gamma * ce_loss).mean()

    return focal_loss

Reference

profile
Keep on dreaming and dreaming

0개의 댓글