쉽게 이해하는 KL Divergence Loss와 VAE 예시

Bean·2025년 6월 17일
0

인공지능

목록 보기
51/123

KL Divergence란?

딥러닝에서 KL Divergence Loss는 모델의 출력 분포(예: softmax 결과)와 정답 분포(예: soft label) 사이의 차이를 측정하는 데 사용됩니다.
특히 soft label이나 Knowledge Distillation을 할 때 자주 등장하는 개념입니다.


기본 예시

클래스P (정답 분포)Q (예측 분포)
00.10.2
10.60.3
20.30.5

KL Divergence Loss 계산 공식:

KL(PQ)=iP(i)log(P(i)Q(i))KL(P||Q) = \sum_i P(i) \cdot \log \left( \frac{P(i)}{Q(i)} \right)
  • 클래스 0: 0.1 × log(0.1/0.2) = -0.06931
  • 클래스 1: 0.6 × log(0.6/0.3) = 0.41586
  • 클래스 2: 0.3 × log(0.3/0.5) = -0.15324

👉 합산하면
KL = -0.06931 + 0.41586 - 0.15324 ≈ 0.19331


🧪 PyTorch 예제

import torch
import torch.nn.functional as F

P = torch.tensor([0.1, 0.6, 0.3])
Q = torch.tensor([0.2, 0.3, 0.5])

loss = F.kl_div(Q.log(), P, reduction='sum')
print(loss.item())  # ≈ 0.19331

VAE에서의 KL Divergence

Variational Autoencoder (VAE)에서는 KL Divergence가 잠재 벡터(latent vector)의 분포를 정규분포로 맞추는 데 사용됩니다.

✔️ 수식

KL(q(zx)p(z))=12i=1d(μi2+σi2logσi21)KL(q(z|x) || p(z)) = \frac{1}{2} \sum_{i=1}^d (\mu_i^2 + \sigma_i^2 - \log \sigma_i^2 - 1)
  • μ: 인코더가 예측한 평균 벡터
  • σ: 인코더가 예측한 표준편차 벡터
  • p(z): 목표 정규분포 N(0, I) / I: 단위 행렬

간단한 계산 예시

예) latent vector 차원 = 2

  • μ = [1.0, -1.0]
  • σ = [0.5, 1.5]

그러면,

KL=12(12+0.52log0.521)+12((1)2+1.52log1.521)KL = \frac{1}{2} (1^2 + 0.5^2 - \log 0.5^2 - 1) + \frac{1}{2} ((-1)^2 + 1.5^2 - \log 1.5^2 - 1)

계산하면 ≈ 1.5377


PyTorch 예제 (VAE)

import torch

mu = torch.tensor([1.0, -1.0])
logvar = torch.log(torch.tensor([0.5**2, 1.5**2]))

kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
print(kl_loss.item())  # ≈ 1.5377

정리

KL Divergence
: 예측 분포가 정답 분포에서 얼마나 벗어났는지 측정

VAE에서의 역할
: Latent Vector가 정규분포를 따르도록 유도하여 안정적 생성


profile
AI developer

0개의 댓글