Intractable Posteriors
앞선 lecture에서 살펴본 대로 ELBO가 likelihood pθ(x)와 같아지는 조건은 q(z)=p(z∣x;lθ)일 때 뿐이다. 따라서 posterior를 구할 수 있으면 likelihood 역시 구할 수 있다. 그러나 posterior 역시 likelihood와 마찬가지로 대부분의 경우 intractable 하다. 이에 대한 대안으로 posterior를 approximation 하는 방법이 있다. q(z;ϕ)가 p(z∣x;θ) 가능한 가까워지도록 ϕ 를 선택하면 좋은 posterior에 대한 좋은 approximation을 얻을 수 있다. 이를 variational inference라고 한다.
The Evidence Lower bound applied to the entire dataset
q(z;ϕ)를 이용한 ELBO는 다음과 같다
- z∑q(z;ϕ)logp(z,x;θ)+H(q(z;ϕ))=L(x;θ,ϕ)
따라서 MLE의 경우를 생각하면 다음의 식을 얻을 수 있다.
- θmaxxi∈D∑logp(xi,θ)≥θ,ϕ1,⋯,ϕMmaxxi∈D∑L(xi;θ,ϕi)
위 식에서 주목 할 점은, ϕi이다. variational parameter ϕ는 모든 data point xi마다 다르다. 왜냐하면 true posterior p(z∣x;θ)가 모든 data point 마다 다르기 때문이다.
Learning Deep Generative Models
Gradient descent 알고리즘을 이용한 학습 시나리오를 생각해보자. 우선 ELBO의 식을 아래처럼 변경 가능하다.
- L(xi;θ,ϕi)=z∑q(z;ϕ)logp(z,x;θ)+H(q(z;ϕ))
=Eq(z;ϕi)[logp(z,xi;θ)−logq(z;ϕi)]
θ와 ϕi를 update 하기 위해 각각의 gradient를 구해야한다. 일반적으로 closed form 형태로 gradient를 구하기 어렵기 때문에 monte carlo sampling을 이용한다.
- Eq(z;ϕ)[logp(z,x;θ)−logq(z;ϕ)]≈K1k∑logp(zk,x;θ)−logq(zk;ϕ) (i는 식의 compactness를 위해 잠시 생략)
위 식으로 부터 얻고 싶은 gradient는 ∇θL(x;θ,ϕ)와 ∇ϕL(x;θ,ϕ)이다. ∇θL(x;θ,ϕ)는 다음과 같이 쉽게 구할 수 있다.
- ∇θEq(z;ϕ)[logp(z,x;θ)−logq(z;ϕ)]=Eq(z;ϕ)[∇θlogp(z,x;θ)]≈K1k∑∇θlogp(zk,x;θ) (q로 부터 z를 K 개 sampling.)
그러나 ∇ϕL(x;θ,ϕ)는 쉽게 구하기가 어렵다. 왜나하면 expectation 값이 ϕ에 관한 것이기 때문이다.
- ∇ϕEq(z;ϕ)[logp(z,x;θ)−logq(z;ϕ)]=Eq(z;ϕ)[∇ϕ(logp(z,x;θ)−logq(z;ϕ))]
따라서 approximation 할 다른 방법이 필요하다.
Reparameterization
z를 적절히 변환하면 z에 대한 gradient를 approximation 할 방법을 찾을 수 있다. 먼저 q(z;ϕ)=N(μ,σ2I)로 가정하면 다음 두 가지의 동일한 sampling 식을 얻을 수 있다.
- Sample z∼q(z;ϕ), ϕ=(μ,σ)
- Sample ϵ∼N(0,I),z=μ+σϵ=g(ϵ;ϕ) (z를 shift and rescale)
위 식에서 g는 deterministic 함수이다. 따라서 앞서 본 expectaion은 다음과 같이 쓸 수 있다.
- Ez∼q(z;ϕ)[r(z)]=∫q(z;ϕ)r(z)dz=Eϵ∼N(0,I)[r(g(ϵ;ϕ)]
- ∇ϕEz∼q(z;ϕ)[r(z)]=∇ϕEϵ[r(g(ϵ;ϕ))]=Eϵ[∇ϕr(g(ϵ;ϕ))]
따라서 앞서 본것 과는 다르게 ∇ϕ를 쉽게 approximation 할 수 있다.(monte carlo 이용)
- Eϵ[∇ϕr(g(ϵ;ϕ))]≈K1k∑∇ϕr(g(ϵk;ϕ)) (ϵ을 N(0,I)로 부터 K개 sampling)
다시 본래의 loss 식 L(x;θ,ϕ)을 생각하면 transformation 식이 r(z)가 아닌 r(z,ϕ) 임을 알 수 있다.
- Eq(z;ϕ)[logp(z,x;θ)−logq(z;ϕ)]=Eq(z;ϕ)[r(z,ϕ)]
조금 더 복잡해지긴 했지만, chain rule 이용하여 전과 같이 쉽게 gradient의 approximation을 구할 수 있다.
- Eq(z;ϕ)[r(z,ϕ)]=Eϵ[r(g(ϵ;ϕ),ϕ)]≈K1k∑r(g(ϵk;ϕ),ϕ), z=μ+σϵ=g(ϵ;ϕ)
Amortized Inference
앞서 설명한대로 variational parameters ϕ는 data point xi에 따라 다르다. 따라서 dataset이 커지게 되면 variational parameters를 학습하는데 무리가 간다. amortization을 이용하면 이러한 문제를 해결할 수 있다.
학습시 모든 variational parameters를 학습하는 것이 아닌, 하니의 parametic function fλ를 학습하는 것이다. fλ는 xi를 ϕi로 mapping 하는 함수이다.
- fλ:xi↦ϕi
따라서 posterior은 q(z;fλ(xi))로 생각될 수 있고 일반적으로 qϕ(z∣x)로 표기된다. 이제 fλ(일반적으로 neural network)를 학습함으로써 posterior를 쉽게 approximation 할 수 있다. 이를 amortized inference라고 한다.
Autoencoder Perspective
앞의 내용들을 총 집합하면 다음과 같은 loss 식을 얻는다.
- L(x;θ,ϕ)=Eqϕ(z∣x)[logp(z,x;θ)−logqϕ(z∣x)]
위 식의 p(z)을 더하고 빼서 다음과 같은 변형식을 얻을 수 있다.
- Eqϕ(z∣x)[logp(z,x;θ)−logp(z)+logp(z)−logqϕ(z∣x)]
=Eqϕ(z∣x)[logp(x∣z;θ)]−DKL(qϕ(z∣x)∣∣p(z))
위 식에서 첫 번째 term은 실제 x와 최대한 같아지도록 만든다. 즉 reconstruction loss로 생각 할 수 있다. 반면 두 번째 term은 qϕ에 의해 sampling 되는 z가 prior p(z)와 최대한 닮도록 만든다. 즉 regularization loss로 생각 할 수 있다.
Reference
cs236 Lecture 6