CrossEntropyLoss
에서 가중치(weight)는 softmax를 적용한 확률이 아니라, 최종 손실 값에 곱해짐. 따라서 가중치는 로그 확률에 곱해서 손실을 조정하는 역할.
1. CrossEntropyLoss 공식
Loss=−N1i=1∑Nweight[yi]⋅log(p^i,yi)
여기서:
- N: 배치 내 샘플 수.
- yi: 샘플 i의 정답 클래스.
- p^i,yi: 모델이 예측한 샘플 i의 정답 클래스 확률 (softmax 값).
- weight[yi]: 정답 클래스 yi에 대응하는 가중치.
2. Softmax 계산
logits
를 사용하여 softmax를 계산합니다. Softmax는 모델이 출력한 raw logits를 확률 분포로 변환.
Softmax 공식:
p^i,j=∑kexp(logiti,k)exp(logiti,j)
예시: 주어진 logits
logits = torch.tensor([[2.0, 1.0, 0.1], [0.5, 2.5, 0.3]])
labels = torch.tensor([0, 2])
weights = torch.tensor([1.0, 2.0, 3.0])
Softmax 계산:
-
샘플 1의 logits: [2.0, 1.0, 0.1]
p^1,0=exp(2.0)+exp(1.0)+exp(0.1)exp(2.0)=e2+e1+e0.1e2≈0.659
p^1,1=exp(2.0)+exp(1.0)+exp(0.1)exp(1.0)≈0.242
p^1,2=exp(2.0)+exp(1.0)+exp(0.1)exp(0.1)≈0.099
-
샘플 2의 logits: [0.5, 2.5, 0.3]
p^2,0=exp(0.5)+exp(2.5)+exp(0.3)exp(0.5)≈0.095
p^2,1=exp(0.5)+exp(2.5)+exp(0.3)exp(2.5)≈0.843
p^2,2=exp(0.5)+exp(2.5)+exp(0.3)exp(0.3)≈0.062
3. 손실 계산
모델의 출력 확률(softmax 값)을 정답 레이블과 비교하여 CrossEntropyLoss를 계산합니다.
3.1 로그 확률 계산:
- 샘플 1의 정답은 클래스 0:
−log(p^1,0)=−log(0.659)≈0.417
- 샘플 2의 정답은 클래스 2:
−log(p^2,2)=−log(0.062)≈2.785
3.2 클래스별 가중치 적용:
- 샘플 1의 정답 클래스 0의 가중치: (1.0)
가중치 적용 손실=1.0⋅0.417=0.417
- 샘플 2의 정답 클래스 2의 가중치: (3.0)
가중치 적용 손실=3.0⋅2.785=8.355
4. 최종 손실 계산
배치의 평균 손실을 계산합니다:
최종 손실=2샘플 1 손실+샘플 2 손실=20.417+8.355≈4.386
5. 요약
CrossEntropyLoss
의 weight
는 클래스별 가중치를 조정하여, 특정 클래스가 손실에 더 큰 영향을 미치도록 설계됨.
- Softmax 확률 계산 후 로그 확률을 손실로 변환하며, 이 손실에 클래스별 가중치가 곱해짐.
- 배치의 평균 손실이 최종 결과로 반환됨.
weight
를 적절히 설정하면 클래스 불균형 문제를 완화하거나, 특정 클래스의 중요도를 조정할 수 있음.