[Pytorch]CrossEntropyLoss(weight)

ma-kjh·2024년 12월 23일
0

Pytorch

목록 보기
24/25

CrossEntropyLoss에서 가중치(weight)softmax를 적용한 확률이 아니라, 최종 손실 값에 곱해짐. 따라서 가중치는 로그 확률에 곱해서 손실을 조정하는 역할.


1. CrossEntropyLoss 공식

Loss=1Ni=1Nweight[yi]log(p^i,yi)\text{Loss} = -\frac{1}{N} \sum_{i=1}^N \text{weight}[y_i] \cdot \log(\hat{p}_{i,y_i})

여기서:

  • NN: 배치 내 샘플 수.
  • yiy_i: 샘플 ii의 정답 클래스.
  • p^i,yi\hat{p}_{i,y_i}: 모델이 예측한 샘플 ii의 정답 클래스 확률 (softmax 값).
  • weight[yi]\text{weight}[y_i]: 정답 클래스 yiy_i에 대응하는 가중치.

2. Softmax 계산

logits를 사용하여 softmax를 계산합니다. Softmax는 모델이 출력한 raw logits를 확률 분포로 변환.

Softmax 공식:

p^i,j=exp(logiti,j)kexp(logiti,k)\hat{p}_{i,j} = \frac{\exp(\text{logit}_{i,j})}{\sum_k \exp(\text{logit}_{i,k})}

예시: 주어진 logits

logits = torch.tensor([[2.0, 1.0, 0.1], [0.5, 2.5, 0.3]])
labels = torch.tensor([0, 2])  # 샘플 1: 클래스 0, 샘플 2: 클래스 2
weights = torch.tensor([1.0, 2.0, 3.0])  # 클래스별 가중치
Softmax 계산:
  • 샘플 1의 logits: [2.0, 1.0, 0.1]

    p^1,0=exp(2.0)exp(2.0)+exp(1.0)+exp(0.1)=e2e2+e1+e0.10.659\hat{p}_{1,0} = \frac{\exp(2.0)}{\exp(2.0) + \exp(1.0) + \exp(0.1)} = \frac{e^2}{e^2 + e^1 + e^{0.1}} \approx 0.659
    p^1,1=exp(1.0)exp(2.0)+exp(1.0)+exp(0.1)0.242\hat{p}_{1,1} = \frac{\exp(1.0)}{\exp(2.0) + \exp(1.0) + \exp(0.1)} \approx 0.242
    p^1,2=exp(0.1)exp(2.0)+exp(1.0)+exp(0.1)0.099\hat{p}_{1,2} = \frac{\exp(0.1)}{\exp(2.0) + \exp(1.0) + \exp(0.1)} \approx 0.099
  • 샘플 2의 logits: [0.5, 2.5, 0.3]

    p^2,0=exp(0.5)exp(0.5)+exp(2.5)+exp(0.3)0.095\hat{p}_{2,0} = \frac{\exp(0.5)}{\exp(0.5) + \exp(2.5) + \exp(0.3)} \approx 0.095
    p^2,1=exp(2.5)exp(0.5)+exp(2.5)+exp(0.3)0.843\hat{p}_{2,1} = \frac{\exp(2.5)}{\exp(0.5) + \exp(2.5) + \exp(0.3)} \approx 0.843
    p^2,2=exp(0.3)exp(0.5)+exp(2.5)+exp(0.3)0.062\hat{p}_{2,2} = \frac{\exp(0.3)}{\exp(0.5) + \exp(2.5) + \exp(0.3)} \approx 0.062

3. 손실 계산

모델의 출력 확률(softmax 값)을 정답 레이블과 비교하여 CrossEntropyLoss를 계산합니다.

3.1 로그 확률 계산:

  • 샘플 1의 정답은 클래스 0:
    log(p^1,0)=log(0.659)0.417-\log(\hat{p}_{1,0}) = -\log(0.659) \approx 0.417
  • 샘플 2의 정답은 클래스 2:
    log(p^2,2)=log(0.062)2.785-\log(\hat{p}_{2,2}) = -\log(0.062) \approx 2.785

3.2 클래스별 가중치 적용:

  • 샘플 1의 정답 클래스 0의 가중치: (1.0)
    가중치 적용 손실=1.00.417=0.417\text{가중치 적용 손실} = 1.0 \cdot 0.417 = 0.417
  • 샘플 2의 정답 클래스 2의 가중치: (3.0)
    가중치 적용 손실=3.02.785=8.355\text{가중치 적용 손실} = 3.0 \cdot 2.785 = 8.355

4. 최종 손실 계산

배치의 평균 손실을 계산합니다:

최종 손실=샘플 1 손실+샘플 2 손실2=0.417+8.35524.386\text{최종 손실} = \frac{\text{샘플 1 손실} + \text{샘플 2 손실}}{2} = \frac{0.417 + 8.355}{2} \approx 4.386

5. 요약

  1. CrossEntropyLossweight는 클래스별 가중치를 조정하여, 특정 클래스가 손실에 더 큰 영향을 미치도록 설계됨.
  2. Softmax 확률 계산 후 로그 확률을 손실로 변환하며, 이 손실에 클래스별 가중치가 곱해짐.
  3. 배치의 평균 손실이 최종 결과로 반환됨.

weight를 적절히 설정하면 클래스 불균형 문제를 완화하거나, 특정 클래스의 중요도를 조정할 수 있음.

profile
거인의 어깨에 올라서서 더 넓은 세상을 바라보라 - 아이작 뉴턴

0개의 댓글