2부터는 기본적으로 "Understanding Diffusion Models : A Unified Perspective"를 보고 수식이나 내용의 베이스를 잡았고 나머지는 레퍼런스에서 소개하는 사이트들을 참고해 살을 붙였다.
들어가기에 앞서 알면 단순하지만 모르면 헷갈리는 오늘 다룰 내용에서 중요한 개념인 marginalize에 대해서 짚고 넘어가 보자.
likelihood fuction
likelihood function은 관측된 sample을 가지고 모든 가능한 parameters value에 대해 확률을 계산하는 함수이다. (파라미터 공간에서 적분을 통해 확률을 계산한다.) 그래서 모델의 probability를 모델 스스로 이해할 수 있고, sample의 확률인 evidence를 참조한다.
Marginal
독립 항등분포 (independent identically distributed) data X=(x1,...,xn) where xi∼p(x∣θ)는 파라미터화된 확률 분포 θ를 따른다. θ는 random variable (i.e. θ∼p(θ∣α))
margianal likelihood는 확률 분포 p(X∣α)를 θ로 marginal out 해주자는 것이다. 이 말은 각 random variable에 대한 joint distribusion이 주어질 때, 각각의 variable의 distribution에 대해서 probability distribution을 구하는 것을 의미한다. 이 말은 즉, 내가 관심있는 변수에 영향을 주는 다른 변수들을 모두 찾아서 해당 변수들에서 내가 관심있는 변수에 영향을 주는 부분을 모두 더하는 것이다. (sum rule-합의 법칙 이라고도 함)
pX(x)=∫pX∣Z(x,z)dz=E[pX∣Z(x∣z)](1)
(extra: 생성 모델에 미리 이 개념을 적용시켜보자면, latent space z에서 입력 이미지 x과 연관된 모든 확률들을 찾아 더해서 x에 대한 확률을 만드는 것이다. 추후에 나오게 되니 이해가 되지 않으면 그런가 하고 넘어가자.)
expectation value의 정의에 따라서 E[pX∣Z(x∣z)]는 이렇게 쓸 수도 있다.
E[pX∣Z(x∣z)]=∫yf(z)pZ(z)dy(2)
Variational Inference
결국 우리가 하고자 하는 것은 관측 데이터에 대한 likelihood를 계산해서 각 variance에 대한 확률 분포를 예측하는 모델을 만들자는 것이다. 모델을 잘 만들면 새로운 입력에 대해서도 어떤 결과를 가지는지 알 수 있기 때문! 이는 다시 말하자면, latent variable의 posterior를 계산하는 것과 같다. (지난 포스팅 MAP 참고)
하지만 이를 계산하려니 문제가 생긴다. 왜냐면 실제로 우리는 모든 variance에 대한 정확한 분포를 알지 못하기 때문이다. 그래서 일단 관측치 x에 대해 주변부의 확률 분포(marginalize)를 계산한다.
chain rule에 의해서 p(x)는 다음처럼 전개될 수 있다.
p(x)=p(z∣x)p(x,z)(3)
이후 수식은 pX∣Z(z∣x) -> p(z∣x) 이런 식으로 쓸거임
우리는 앞선 marginal distribution에 대한 식 (1)과 (3)에 따라서 likelihood를 직접 계산하지 않고 근사화하여 추론할 수 있다. p(x)의 likelihood를 직접적으로 이용하는 대신 최소가 되는 evidence를 최대화하자. (ELBO)
ELBO (Evidence Lower BOund)
우리는 우리는 데이터 x에 대한 확률 분포 p의 likelihood를 직접 구하는 대신 ϕ로 파라미터화 된 qϕ(z∣x)를 최적화 하길 원한다.
(1)식의 양변에 log를 취하고 식을 전개하면 다음과 같은 추가식을 얻을 수 있다.
logp(x)=log∫p(x,z)dz=log∫qϕ(z∣x)p(x,z)qϕ(z∣x)=logEqϕ(z∣x)[qϕ(z∣x)p(x,z)]≥Eqϕ(z∣x)[qϕ(z∣x)log[p(x,z)]
- log∫qϕ(z∣x)p(x,z)qϕ(z∣x) : 1을 곱함 - 분자 분모에 값은 값을 곱함
- logEqϕ(z∣x)[qϕ(z∣x)p(x,z)] : Definition of Exceptation
- ≥Eqϕ(z∣x)[qϕ(z∣x)log[p(x,z)] : Jensen's Inequality
(3)식을 전개해보자.
logp(x)=logp(x)∫qϕ(z∣x)dz=∫qϕ(z∣x)(logp(x))dz=Eqϕ(z∣x)[logp(x)]=Eqϕ(z∣x)[logp(z∣x)qϕ(z∣x)p(x,z)qϕ(z∣x)]=Eqϕ(z∣x)[logqϕ(z∣x)p(x,z)]+Eqϕ(z∣x)[logqϕ(z∣x)p(z∣x]=Eqϕ(z∣x)[logqϕ(z∣x)p(x,z)]+DKL(qϕ(z∣x)∣∣p(z∣x))≥Eqϕ(z∣x)[logqϕ(z∣x)p(x,z)]
KL divergence는 언제나 0 이상이므로 이 식의 최소값은 Eqϕ(z∣x)[logqϕ(z∣x)p(x,z)]에 따라 결정되게 된다. (lower bound)
❤ Reference
wikipidia
blog