예시:
logits = model(x) # raw score
softmax
(multi-class) 또는 sigmoid
(multi-label)를 적용한 값입니다.특징
[0.1, 0.7, 0.2]
argmax
를 통해 얻습니다.예시:
Soft Prob: [0.1, 0.7, 0.2]
→ Hard Prob: [0, 1, 0]
특징
상황 | Logit | Soft Prob | Hard Prob |
---|---|---|---|
학습(Training) | 모델 출력(raw) | ✅ Loss 계산 | ❌ |
검증(Validation) | 가능 | 선택적 | ✅ 일반적으로 사용 |
테스트/배포(Inference) | ❌ | ❌ | ✅ 최종 예측 결과 |
def dice_loss(pred, target, smooth=1e-6):
pred_soft = F.softmax(pred, dim=1) # Soft probability
target_one_hot = F.one_hot(target, num_classes=pred.size(1)).permute(0, 2, 1).float()
intersection = (pred_soft * target_one_hot).sum(dim=2)
dice_score = (2 * intersection + smooth) / (pred_soft.sum(dim=2) + target_one_hot.sum(dim=2) + smooth)
return 1 - dice_score.mean()
def dice_score(pred_logits, target, smooth=1e-6):
pred_classes = torch.argmax(torch.softmax(pred_logits, dim=1), dim=1) # Hard prediction
pred_one_hot = F.one_hot(pred_classes, num_classes=pred_logits.size(1)).permute(0, 2, 1).float()
target_one_hot = F.one_hot(target, num_classes=pred_logits.size(1)).permute(0, 2, 1).float()
intersection = (pred_one_hot * target_one_hot).sum(dim=2)
dice_scores = (2 * intersection + smooth) / (pred_one_hot.sum(dim=2) + target_one_hot.sum(dim=2) + smooth)
return dice_scores.mean()
👉 따라서,