cross-entropy, KL-divergence 정리

pyross·2024년 8월 29일

공부

목록 보기
2/5

우연치 않게 팡요랩 KL divergence 영상을 시청하게 되었는데 매우 유익해서 정리를 하려고 한다.

우선 간단하게 설명을 하자면
cross-entropy는 KL-divergence에서 나왔다.
KL-divergence부터 설명을 하고 진행을 하겠다.

정보 공부

KL-divergence는 2개의 분포의 추측 entropy와 실제 entropy 차이를 의미한다.
여기에서 entropy란?
정보를 표현하는 단위이다.

이때 정보를 표현하기 위해서는 예시를 들어 공부해보자
임의의 확률 x,y가 있을때 각 확률이 아래와 같다고 하자.
p(x)=0.999,p(y)=0.001p(x)=0.999, p(y)=0.001
이때 정보는 y가 더 많다.

왜??

정보 "해가 동쪽에서 떴어" 와 "해가 서쪽에서 떴어"가 존재할 때
해가 동쪽에서 뜨는 것은 매일 일어나는 일이고 매우 당연한 정보이다. 그렇기 때문에
빈도수가 많고 이에 담긴 정보는 적다.
그러나 해가 서쪽에서 뜨는 것은 엄청나게 적은 빈도를 가지고 있기에 매우 큰 정보를 담고 있다.

여기에서 우리는 추측을 할 수 있다.

  • 아! 정보의 양은 사건의 확률에 반비례 하는구나!

여기에서 이제 예시를 추가해서 분포를 합쳐보자
p(x,y)=0.999p(x,y)=0.999와 같이 두개의 독립인 분포를 동시에 고려하려면 어떻게 해야할까?
여기에서 각 확률변수는 독립이기 때문에 정보를 분리를 할 수 있어야 한다.
즉 정보량을 h로 표현할 때
p(x,y)=p(x)p(y)p(x,y)=p(x)*p(y)가 되어야하기에.
h(x,y)=h(x)+h(y)h(x,y)=h(x)+h(y)로 표현이 되어야한다.

  • 이를 잘 보여주는 함수는 log이다.
    log(xy)=logx+logy\log(x*y)=\log x+\log y이기 때문.

위의 정보의 양과 사건의 확률에 합쳐
우리는 정보를
h(x)=log2p(x)h(x)=-\log_2p(x)로 표현할 수 있다.(곱셈 분할이 가능하고 확률에 반비례)
정보를 이렇게 표현하기로 합의를 한것이다.

여기에서 1개의 문장에 담긴 평균적인 정보는 각 사건이 발생할 확률을 곱해

H(X)=xp(x)log2p(x)H(X)=-\sum_x p(x)\log_2p(x)로 표현이 되고
이는 entropy이다.

KL-divergence

이제 KL-divergence를 배울 준비가 완료되었다.

KL-divergence는 실제 정답 분포 p(x)p(x)와 내가 예측한 분포 q(x)q(x)의 차이를 나타내고자 하는 식이다.

이때 내가 예측한 분포가 q(x)q(x)일때 실제 분포가 p(x)p(x)인 경우
이때의 entropy를 구해보면 다음과 같다.
H(X)=xp(x)log2q(x)H(X)=-\sum_x p(x)\log_2q(x)
왜냐하면 사건의 발생할 확률(실제 분포)는 p(x)p(x)이고 내가 추측한 분포는 q(x)q(x)이기 때문에 위와 같다.
이때 위 entropy는 H(X)=xp(x)log2p(x)H(X)=-\sum_x p(x)\log_2p(x)보다 무조건 더 크게 된다.
예측한 분포가 실제 최적의 분포와 다르기 때문이다.

이렇게 예측과 실제가 다르게 되면 불필요한 noise entropy가 추가되게 되고 이 차이를 KL-divergence라고 부른다.


LKL=(xp(x)log2q(x))(xp(x)log2p(x))L_{KL}=(-\sum_x p(x)\log_2q(x)) -(-\sum_x p(x)\log_2p(x))가 되고
위 식을 정리하면
LKL=xp(x)log2q(x)p(x)L_{KL}=-\sum_xp(x)\log_2{\frac{q(x)}{p(x)}}가 되는 것이다.

cross-entropy

cross-entropy는 보통 classification에서 사용하는 loss이다.
예측 분포와 추측 분포의 차이를 줄이기 위해서 사용하는데
예측 분포는 가중치 θ\theta에 대한 결과이다.
이를 학습하기 위해서 LKL\nabla L_{KL}과 같이 가능한데
이때 LKL=(xp(x)log2q(x))(xp(x)log2p(x))L_{KL}=(-\sum_x p(x)\log_2q(x)) -(-\sum_x p(x)\log_2p(x))에서
오른쪽의 항은 p(x)p(x)로만 이루어져있기에 가중치가 없고 결정된 값이다.
그렇기 때문에 이를 제외해도 학습에서는 아무런 의미가 없다.

결국 LCE=xp(x)log2q(x)L_{CE}=-\sum_x p(x)\log_2q(x)으로 p(x)p(x)q(x)q(x)의 분포 차이를 줄이는 loss를 나타내게 된다.

0개의 댓글