Cs236 Lecture6

JInwoo·2025년 1월 8일

cs236

목록 보기
4/15

Intractable Posteriors

앞선 lecture에서 살펴본 대로 ELBO가 likelihood pθ(x)p_{\theta}(\mathbf{x})와 같아지는 조건은 q(z)=p(zx;lθ)q(\mathbf{z})=p(\mathbf{z|x};l\theta)일 때 뿐이다. 따라서 posterior를 구할 수 있으면 likelihood 역시 구할 수 있다. 그러나 posterior 역시 likelihood와 마찬가지로 대부분의 경우 intractable 하다. 이에 대한 대안으로 posterior를 approximation 하는 방법이 있다. q(z;ϕ)q(\mathbf{z};\phi)p(zx;θ)p(\mathbf{z|x};\theta) 가능한 가까워지도록 ϕ\phi 를 선택하면 좋은 posterior에 대한 좋은 approximation을 얻을 수 있다. 이를 variational inference라고 한다.

The Evidence Lower bound applied to the entire dataset

q(z;ϕ)q(\mathbf{z};\phi)를 이용한 ELBO는 다음과 같다

  • zq(z;ϕ)logp(z,x;θ)+H(q(z;ϕ))=L(x;θ,ϕ)\underset{\mathbf{z}}{\sum}q(\mathbf{z};\phi)\log p(\mathbf{z,x};\theta) + H(q(\mathbf{z;\phi))}=\mathcal{L(\mathbf{x};\theta,\phi)}

따라서 MLE의 경우를 생각하면 다음의 식을 얻을 수 있다.

  • maxθxiDlogp(xi,θ)maxθ,ϕ1,,ϕMxiDL(xi;θ,ϕi)\underset{\theta}{\max}\underset{\mathbf{x}^i\in\mathcal{D}}{\sum}\log p(\mathbf{x}^i,\theta)\ge\underset{\theta,\phi^1,\cdots,\phi^M}{\max}\underset{\mathbf{x}^i\in\mathcal{D}}{\sum}\mathcal{L}(\mathbf{x}^i;\theta,\phi^i)

위 식에서 주목 할 점은, ϕi\phi^i이다. variational parameter ϕ\phi는 모든 data point xi\mathbf{x}^i마다 다르다. 왜냐하면 true posterior p(zx;θ)p(\mathbf{z|x};\theta)가 모든 data point 마다 다르기 때문이다.

Learning Deep Generative Models

Gradient descent 알고리즘을 이용한 학습 시나리오를 생각해보자. 우선 ELBO의 식을 아래처럼 변경 가능하다.

  • L(xi;θ,ϕi)=zq(z;ϕ)logp(z,x;θ)+H(q(z;ϕ))\mathcal{L}(\mathbf{x}^i;\theta,\phi^i)=\underset{\mathbf{z}}{\sum}q(\mathbf{z};\phi)\log p(\mathbf{z,x};\theta) + H(q(\mathbf{z;\phi))}
    =Eq(z;ϕi)[logp(z,xi;θ)logq(z;ϕi)]\qquad\qquad\quad=E_{q(\mathbf{z};\phi^i)}[\log p(\mathbf{z,x}^i;\theta)-\log q(\mathbf{z};\phi^i)]

θ\thetaϕi\phi^i를 update 하기 위해 각각의 gradient를 구해야한다. 일반적으로 closed form 형태로 gradient를 구하기 어렵기 때문에 monte carlo sampling을 이용한다.

  • Eq(z;ϕ)[logp(z,x;θ)logq(z;ϕ)]1Kklogp(zk,x;θ)logq(zk;ϕ)E_{q(\mathbf{z};\phi)}[\log p(\mathbf{z,x};\theta)-\log q(\mathbf{z};\phi)]\approx\frac{1}{K}\underset{k}{\sum}\log p(\mathbf{z}^k, \mathbf{x};\theta) -\log q(\mathbf{z}^{k};\phi) (ii는 식의 compactness를 위해 잠시 생략)

위 식으로 부터 얻고 싶은 gradient는 θL(x;θ,ϕ)\nabla_{\theta}\mathcal{L}(\mathbf{x};\theta,\phi)ϕL(x;θ,ϕ)\nabla_{\phi}\mathcal{L}(\mathbf{x};\theta,\phi)이다. θL(x;θ,ϕ)\nabla_{\theta}\mathcal{L}(\mathbf{x};\theta,\phi)는 다음과 같이 쉽게 구할 수 있다.

  • θEq(z;ϕ)[logp(z,x;θ)logq(z;ϕ)]=Eq(z;ϕ)[θlogp(z,x;θ)]1Kkθlogp(zk,x;θ)\nabla_{\theta}E_{q(\mathbf{z};\phi)}[\log p(\mathbf{z ,x};\theta) -\log q(\mathbf{z};\phi)]=E_{q(\mathbf{z};\phi)}[\nabla_{\theta}\log p(\mathbf{z ,x};\theta)]\approx\frac{1}{K}\underset{k}{\sum}\nabla_{\theta}\log p(\mathbf{z}^k,\mathbf{x};\theta) (qq로 부터 z\mathbf{z}KK 개 sampling.)

그러나 ϕL(x;θ,ϕ)\nabla_{\phi}\mathcal{L}(\mathbf{x};\theta,\phi)는 쉽게 구하기가 어렵다. 왜나하면 expectation 값이 ϕ\phi에 관한 것이기 때문이다.

  • ϕEq(z;ϕ)[logp(z,x;θ)logq(z;ϕ)]Eq(z;ϕ)[ϕ(logp(z,x;θ)logq(z;ϕ))]\nabla_{\phi}E_{q(\mathbf{z};\phi)}[\log p(\mathbf{z, x};\theta) - \log q(\mathbf{z};\phi)] \ne E_{q(\mathbf{z};\phi)}[\nabla_{\phi}(\log p(\mathbf{z, x};\theta) - \log q(\mathbf{z};\phi))]

따라서 approximation 할 다른 방법이 필요하다.

Reparameterization

z\mathbf{z}를 적절히 변환하면 z\mathbf{z}에 대한 gradient를 approximation 할 방법을 찾을 수 있다. 먼저 q(z;ϕ)=N(μ,σ2I)q(\mathbf{z};\phi)=\mathcal{N}(\mu,\sigma^2I)로 가정하면 다음 두 가지의 동일한 sampling 식을 얻을 수 있다.

  • Sample zq(z;ϕ)\mathbf{z}\sim q(\mathbf{z};\phi), ϕ=(μ,σ)\phi=(\mu,\sigma)
  • Sample ϵN(0,I),z=μ+σϵ=g(ϵ;ϕ)\epsilon\sim\mathcal{N}(0,I),\mathbf{z}=\mu+\sigma\epsilon=g(\epsilon;\phi) (z\mathbf{z}를 shift and rescale)

위 식에서 gg는 deterministic 함수이다. 따라서 앞서 본 expectaion은 다음과 같이 쓸 수 있다.

  • Ezq(z;ϕ)[r(z)]=q(z;ϕ)r(z)dz=EϵN(0,I)[r(g(ϵ;ϕ)]E_{\mathbf{z}\sim q(\mathbf{z};\phi)}[r(\mathbf{z})]=\int q(\mathbf{z};\phi)r(\mathbf{z})d\mathbf{z}=E_{\epsilon\sim\mathcal{N}(0,I)}[r(g(\epsilon;\phi)]
  • ϕEzq(z;ϕ)[r(z)]=ϕEϵ[r(g(ϵ;ϕ))]=Eϵ[ϕr(g(ϵ;ϕ))]\nabla_{\phi}E_{\mathbf{z}\sim q(\mathbf{z};\phi)}[r(\mathbf{z})]=\nabla_{\phi}E_{\epsilon}[r(g(\epsilon;\phi))]=E_{\epsilon}[\nabla_{\phi}r(g(\epsilon;\phi))]

따라서 앞서 본것 과는 다르게 ϕ\nabla_{\phi}를 쉽게 approximation 할 수 있다.(monte carlo 이용)

  • Eϵ[ϕr(g(ϵ;ϕ))]1Kkϕr(g(ϵk;ϕ))E_{\epsilon}[\nabla_{\phi}r(g(\epsilon;\phi))]\approx\frac{1}{K}\underset{k}{\sum}\nabla_{\phi}r(g(\epsilon^k;\phi)) (ϵ\epsilonN(0,I)\mathcal{N}(0, I)로 부터 KK개 sampling)

다시 본래의 loss 식 L(x;θ,ϕ)\mathcal{L}(\mathbf{x};\theta,\phi)을 생각하면 transformation 식이 r(z)r(\mathbf{z})가 아닌 r(z,ϕ)r(\mathbf{z},\phi) 임을 알 수 있다.

  • Eq(z;ϕ)[logp(z,x;θ)logq(z;ϕ)]=Eq(z;ϕ)[r(z,ϕ)]E_{q(\mathbf{z};\phi)}[\log p(\mathbf{z, x};\theta) - \log q(\mathbf{z};\phi)]=E_{q(\mathbf{z};\phi)}[r(\mathbf{z,\phi})]

조금 더 복잡해지긴 했지만, chain rule 이용하여 전과 같이 쉽게 gradient의 approximation을 구할 수 있다.

  • Eq(z;ϕ)[r(z,ϕ)]=Eϵ[r(g(ϵ;ϕ),ϕ)]1Kkr(g(ϵk;ϕ),ϕ), z=μ+σϵ=g(ϵ;ϕ)E_{q(\mathbf{z};\phi)}[r(\mathbf{z,\phi})]=E_{\epsilon}[r(g(\epsilon;\phi),\phi)]\approx\frac{1}{K}\underset{k}{\sum}r(g(\epsilon^k;\phi),\phi),\ \mathbf{z}=\mu+\sigma\epsilon=g(\epsilon;\phi)

Amortized Inference

앞서 설명한대로 variational parameters ϕ\phi는 data point xi\mathbf{x}^i에 따라 다르다. 따라서 dataset이 커지게 되면 variational parameters를 학습하는데 무리가 간다. amortization을 이용하면 이러한 문제를 해결할 수 있다.

학습시 모든 variational parameters를 학습하는 것이 아닌, 하니의 parametic function fλf_\lambda를 학습하는 것이다. fλf_\lambdaxi\mathbf{x}^iϕi\phi^i로 mapping 하는 함수이다.

  • fλ:xiϕif_\lambda:\mathbf{x}^i\mapsto\phi^i

따라서 posterior은 q(z;fλ(xi))q(\mathbf{z};f_\lambda(\mathbf{x}^i))로 생각될 수 있고 일반적으로 qϕ(zx)q_\phi(\mathbf{z|x})로 표기된다. 이제 fλf_\lambda(일반적으로 neural network)를 학습함으로써 posterior를 쉽게 approximation 할 수 있다. 이를 amortized inference라고 한다.

Autoencoder Perspective

앞의 내용들을 총 집합하면 다음과 같은 loss 식을 얻는다.

  • L(x;θ,ϕ)=Eqϕ(zx)[logp(z,x;θ)logqϕ(zx)]\mathcal{L}(\mathbf{x};\theta,\phi)=E_{q_\phi(\mathbf{z|x})}[\log p(\mathbf{z,x};\theta) -\log q_\phi(\mathbf{z|x})]

위 식의 p(z)p(\mathbf{z})을 더하고 빼서 다음과 같은 변형식을 얻을 수 있다.

  • Eqϕ(zx)[logp(z,x;θ)logp(z)+logp(z)logqϕ(zx)]E_{q_\phi(\mathbf{z|x})}[\log p(\mathbf{z,x};\theta)-\log p(\mathbf{z})+\log p(\mathbf{z}) -\log q_\phi(\mathbf{z|x})]
    =Eqϕ(zx)[logp(xz;θ)]DKL(qϕ(zx)p(z))= E_{q_\phi(\mathbf{z|x})}[\log p(\mathbf{x|z};\theta)]-D_{KL}(q_\phi(\mathbf{z|x)}||p(\mathbf{z}))

위 식에서 첫 번째 term은 실제 x\mathbf{x}와 최대한 같아지도록 만든다. 즉 reconstruction loss로 생각 할 수 있다. 반면 두 번째 term은 qϕq_\phi에 의해 sampling 되는 z\mathbf{z}가 prior p(z)p(\mathbf{z})와 최대한 닮도록 만든다. 즉 regularization loss로 생각 할 수 있다.

Reference

cs236 Lecture 6

profile
Jr. AI Engineer

0개의 댓글