[논문 리뷰] Auto-Encoding Variational Bayes

VAE

Variational Auto- Encoder는 인코더를 거치면 평균, 표준편차의 두개의 벡터를 아웃풋으로 내게 되는데, 이 두가지를 결합해서 어떤 Gaussian Distribution을 만들게 되고,

그렇게 만든 분포에서 Sampling을 통해 Z를 만들고 그 Z가 다시 decoder를 통과해 어떤 기존의 input 데이터와 유사한 새로운 데이터들을 생성 할 수 있다.

이 VAE는 해당 확률 분포들을 이용해서 어떤 새로운 데이터를 생성하는 것을 목적으로 개발된 모델이므로 생성형 모델이라고 할 수 있다.

기존 Variational Bayesian (이하 VB) 방법들은 두가지 문제점이 있는데.

  • 1: 해당 생성을 위해서 θ\theta 의 선택에 대해 xx가 있을때 LatentValueLatent\, Value ZZ에 대한 근사가 필요한데 변분 하한에 포함된 기댓값인

    Eqϕ(zx)[logqϕ(zx)]E_{q\phi(z|x)}[log_{q\phi}(z|x)]\,

    를 풀기위해 해당 pθ(z)pθ(xz)dz\int p_\theta (z)p_\theta(x|z)dz를 계산해야 하는데 해당 Posterior density인 pθ(zx)p_\theta(z|x)intractable하므로 해당 기댓값을 해석적으로 계산하기 어렵다.

  • 2 : 대용량 데이터 셋에 대해서 기존 방법들이 비효율적이다.
    와 같은 두가지 문제점이 있었다.

따라서 해당 논문은 다음과 같은 해결책을 제시한다.
1. Reparameterization Trick : 확률적 노드에 대한 역전파를 가능하게 함
2. SGVB(Stochastic Gradient Variational Bayes): 확장가능한 근사 추론 알고리즘 개발
3. AEVB(Auto-Encoding Variational Bayes): 연속 잠재변수를 가진 생성모델의 효율적 학습 프레임 워크
4. VAE(Variational AutoEncoder): 신경망 기반의 실용적 구현 제시

결국 구현하고자 하는 것은

연속적인 Latent Variable과 Parameter가 다루기 힘든 사후분포를 갖는 Directed Probabilistic Model을 통해 효율적으로 Approximate Inference를 시행!

이제 구체적으로 알고리즘이 어떻게 되는지 Method 부분을 세세하게 나누어 확인해보자.

Method(2)

2.1 Problem Scenario


XX: 연속적인, 혹은 이산적인 변수 x,x, NN개의 i.i.d 샘플을 포함하는 데이터 셋
zz: 관측되지 않은 연속형 랜덤 변수 (Latent Variable)
xx: zz를 포함하는 어떤 random process에 의해 생성된 변수 (사전 분포 p(z)p(z)에 의해 생성됨)

이라고 했을때, 앞서말한 2가지 문제점에 대해서 해당 논문은 3가지 해결책을 제시했다.

  1. 파라미터 θ\theta에 대한 효과적인 근사치인 ML 또는 MAP추정.
    ML: Maximum Likelihood :어떤 상황이 주어졌을 때 우도를 최대화
    MAP: ML이 Likelihood를 최대화하는 방법이라면, MAP은 Posterior를 최대화 시키는 방법, 이때 Bayes Rule을 적용
  2. 관측값 xx가 주어졌을 때 잠재변수 zz의 효과적인 근사치 추정
  3. xx의 효율적인 Approximate Marginal inference

이때, 우리가 알고자 하는 pθ(x)p_\theta(x)를 직접 알아낼 수는 없으므로 controlable한 Latent Variable zz를 활용하여 pθ(x)p_\theta(x)를 간접적으로 푼다. (결합확률 분포를 사용)

pθ(x)=pθ(x,z)dz=pθ(z)p(xz)dzp_\theta(x) =\int p_\theta(x,z)dz =\int p_\theta(z)p(x|z)dz

여기서 Intractable한 pθp_\theta를 알기위해 Variational inference를 사용한다.

해당 추론은 logpθ(x)log\,p_\theta(x), 즉 generative model의 parameter θ\theta 하에서 관측된 xx의 주변 로그 우도를 최대화 하는 방식으로 진행된다.

logpθ(x)=근사/실제분포차이+ELBOlog\,p_\theta(x) = 근사/실제\,분포차이+ELBO

해당식을 전개해보면 다음과 같다.

logpθ(x)=DKL(qϕ(zx)pθ(zx)+L(θ,ϕ;x)log\,p_\theta(x) =D_{KL}(q_\phi(z|x)||p_\theta(z|x)\,+\,L(\theta,\phi;x)

DKL(qϕ(zx)pθ(zx)D_{KL}(q_\phi(z|x)||p_\theta(z|x): 근사 사후 분포 (qϕ(zx)q_\phi(z|x)와 실제 사후분포 pθ(zx)p_\theta(z|x) 사이의 차이

L(θ,ϕ;x)L(\theta,\phi;x): 변분 하한(Variational LowerBound), 혹은 증거 하한(Evidence Lower Bound; ELBO,

DKLD_{KL} 는 항상 양수이므로 (jensenss  Inequality 증명에 의거)jensens's~~ Inequality ~증명에 ~의거)
logpθ(x)L(θ,ϕ;x)log\,p_\theta(x) \geq L(\theta,\phi;x) 가 성립한다. 따라서 LL을 최적화(최대화) 하면 KL을 최소화 하면서 해당 logpθ(x)log\,p_\theta(x)를 최대화 할 수 있다.

해당 ELBO의 식을 보면 다음과 같다.

L(θ,ϕ;x)=Eqϕ(zx)[logqϕ(zx)+logpθ(x,z)]L(\theta,\phi;x) = E_{q\phi(z|x)}[-log\,q_\phi(z|x)\,+\,log\,p_\theta(x,z)]

이때 ,BayesRuleBayes\, Rule에 의거한 결합분포 분해를 적용,

BayseRule:logpθ(x,z)=logpθ(xz)+logpθ(z)Bayse\, Rule:log\,p_\theta(x,z) =log\,p_\theta(x|z) +log\,p_\theta(z)

해당 공식을 위에 대입하면

Eqϕ(zx)[logpθ(xz)]+Eqϕ(zx)[logpθ(z)]Eqϕ(zx)[logqθ(zx)],E_{q\phi(z|x)}[log\,p_\theta(x|z)]\,+\,E_{q\phi(z|x)}[log\,p_\theta(z)]\,-E_{q\phi(z|x)}[log\,q_\theta(z|x)],

두 확률분포 q,pq,p에 대한 KLdivergenceKL-divergence는 다음과 같이 정의되므로

DKL=Ezq[logq(z)logp(z)]D_{KL}=E_{z\sim q}[log\,q(z)-log\,p(z)]

위의 2번째 항과 3번째 항을 DKLD_{KL}로 묶을 수 있다.

=DKL(qϕ(zx)pθ(z))=-D_{KL}(q_\phi(z |x)||p_\theta(z))

이제 남은 첫번째 항과 해당식을 결합하면 다음과 같은 식을 얻는다

ELBO=Eqϕ(zx)[logpθ(xz)]DKL(qϕ(zx)pθ(z))ELBO =E_{q\phi(z|x)}[log\,p_\theta(x|z)]\,-D_{KL}(q_\phi(z |x)||p_\theta(z))

이때 앞의 기댓값 항이 음의 재구성 오차이며,(Latent variable로 복구된 결과 x에 대한 기댓값)
뒤의 KL항이 잠재공간 분포가 사전 분포와 유사하도록 하는 정규화 부분이다.

2.3 SGVB & AEVB algorithm

2.2에서 문제를 정의를 했다면 이제 해당 알고리즘을 풀어야 하는데,
Latent Variable zz를 직접 샘플링 할 경우 indiferetiable하므로 역전파가 불가능하다.

따라서 이를 위해 Reparameterization Trick을 사용한다.

  • Reparameterization Trick

Eqϕ(zx)[logpθ(xz)]E_{q\phi(z|x)}[log\,p_\theta(x|z)]의 식을 보면 해당 기댓값의 분포 qϕ(zx)q_\phi(z|x)ϕ\phi에 의존하지만 logpθ(xz)log\,p_\theta(x|z)ϕ\phi 에 의존하지 않으므로 Sampling이 Stochastic Result로 나온다.

따라서 연구자들이 제시한 방법은 분포 qϕ(zx)q_\phi(z|x)에서 샘플링하는 것이 아닌
deterministic한 분포에서 샘플링 후 gϕg_\phi로 변환하는 것! (주로 가우시안을 쓴다)

이렇게 재파라미터화 방법을 취하면 ϕ\phi에 의해 샘플링이 되는 것은 같지만 결정경로를 gϕg_\phi로 표현할 수 있다. 즉 잠재변수 Sampling Processfmf Random이 아닌 명시적 함수로서 표현 가능하다.

수식적으로 비교해보면, 기존 방식

zqϕ(zx)=N(μϕ(x),diag(σϕ2(x))z\sim q_\phi(z|x) = \mathcal{N}(\mu_\phi(x),\,diag(\sigma^2_\phi(x))

재파라미터화 (가우시안)

z=gϕ(ϵ,x)=μϕ(x)+σϕ(x)ϵz=g_\phi(\epsilon,x) =\mu_\phi(x) + \sigma_\phi(x)\odot\epsilon

따라서 zz가 랜덤한 샘플링에서 명시적인 함수가 되므로 zzϕ\phi에 대한 미분가능한 함수가 되므로 역전파가 가능해진다.

  • Monte Carlo
    재파라미터화를 진행하고, L(θ,ϕ;x)L(\theta,\phi;x)를 풀어보면 logpθ(xz)logp_\theta(x|z)항 (재구성 오차)항의 기댓값을 구해야 한다. 기존 방법으로는 해석적 계산 / 수치적 적분 방법이 있다.
  1. 해석적 계산
    RdN(z;μϕ(x),σϕ2(x))logpθ(xz)dz\int_{\mathbb{R}^d} \mathcal{N}(z; \mu_\phi(x), \sigma_\phi^2(x)) \log p_\theta(x|z) dz
    하지만 Latent value zz가 고차원이고, pθ(xz)p_\theta(x|z)가 복잡한 신경망이므로 해석적 해가 존재 하지않는다.
  2. 수치적 적분
    격자 기반 적분의 계산 복잡도가 O(Nd)∝ O(N^d)에 수렴하고, 가우시안 구적법 또한 여전히 고차원에서 비효율적이다. 따라서 해당 논문은 Monte Calro근사를 활용한다.

몬테카를로 근사의 식은 다음과 같다

Eqϕ(zx)[logpθ(xz)]1Ll=1Llogpθ(xz(l))\mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] \approx \frac{1}{L} \sum_{l=1}^{L} \log p_\theta(x|z^{(l)})

재파라미터화로 샘플링 한 후 몬테카를로 근사를 적용하면
최종 ELBO는 다음과 같다.

in    ϵlN(0,I)  ,zl=gϕ(ϵ,x)=μϕ(x)+σϕ(x)ϵl,in~~~~\epsilon^{l}\sim\mathcal{N}(0,I)~~, z^{l}=g_\phi(\epsilon,x) =\mu_\phi(x) + \sigma_\phi(x)\odot\epsilon^{l},
ELBO:    L~=1Ll=1Llogpθ(xz(l))DKL(qϕ(zx)p(z))ELBO:~~~~\widetilde{\mathcal{L}} = \frac{1}{L} \sum_{l=1}^{L} \log p_\theta(x|z^{(l)}) - D_{KL}(q_\phi(z|x) \| p(z))

따라서 재구성오차의 그래디언트 또한 근사가 가능하므로

  1. 재구성오차의 그래디언트
    ϕEqϕ(zx)[logpθ(xz)]ϕlogpθ(xz)\nabla_\phi \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] \approx \nabla_\phi \log p_\theta(x|z)
  2. zz에 대한 편미분
    zϕ=μϕ(x)ϕ+σϕ(x)ϕϵ\frac{\partial z}{\partial \phi} = \frac{\partial \mu_\phi(x)}{\partial \phi} + \frac{\partial \sigma_\phi(x)}{\partial \phi} \odot \epsilon
  3. 연쇄법칙
    ϕlogpθ(xz)=logpθ(xz)zzϕ\nabla_\phi \log p_\theta(x|z) = \frac{\partial \log p_\theta(x|z)}{\partial z} \frac{\partial z}{\partial \phi}
  4. 전체 ELBO의 그래디언트
    ϕL~=ϕlogpθ(xz)ϕDKL(qϕ(zx)p(z))\nabla_\phi \widetilde{\mathcal{L}} = \nabla_\phi \log p_\theta(x|z) - \nabla_\phi D_{KL}(q_\phi(z|x) \| p(z))
    와 같이 구할 수 있다.

추가로, 가우시안 분포로 설정한다면 L은 다음과 같다.

L(θ,ϕ;x(i))12j=1J(1+log((σj(i))2)(μj(i))2(σj(i))2)+1Ll=1Llogpθ(x(i)z(i,l))\mathcal{L}(\boldsymbol{\theta}, \boldsymbol{\phi}; \mathbf{x}^{(i)}) \approx \frac{1}{2} \sum_{j=1}^{J} \left(1 + \log\left((\sigma_j^{(i)})^2\right) - (\mu_j^{(i)})^2 - (\sigma_j^{(i)})^2\right) + \frac{1}{L} \sum_{l=1}^{L} \log p_{\boldsymbol{\theta}}(\mathbf{x}^{(i)}|\mathbf{z}^{(i,l)})
where z(i,l)=μ(i)+σ(i)ϵ(l)andϵ(l)N(0,I)\text{where } \mathbf{z}^{(i,l)} = \boldsymbol{\mu}^{(i)} + \boldsymbol{\sigma}^{(i)} \odot \boldsymbol{\epsilon}^{(l)} \quad \text{and} \quad \boldsymbol{\epsilon}^{(l)} \sim \mathcal{N}(0, \mathbf{I})

Algorithm


간단하게 설명하면
1. XM=(xi)i=1MX^M =({x^i})^M_{i=1}의 미니배치 샘플링
2. 각 데이터 포인트에 대해(  i  ~~i~~)

  • 노이즈 샘플링 ϵ(i)N(0,I)\epsilon^{(i)} \sim \mathcal{N}(0, I)
  • 인코더 forward pass μ(i),logσ2(i)=encoderϕ(x(i))\mu^{(i)}, \log \sigma^{2(i)} = \text{encoder}_\phi(x^{(i)})
  • 재파라미터화 z(i)=μ(i)+σ(i)ϵ(i)z^{(i)} = \mu^{(i)} + \sigma^{(i)} \odot \epsilon^{(i)} , where σ(i)=exp(12logσ2(i))\text{where } \sigma^{(i)} = \exp\left(\frac{1}{2} \log \sigma^{2(i)}\right)
  • 디코더 forward pass x^(i)=decoderθ(z(i))\hat{x}^{(i)} = \text{decoder}_\theta(z^{(i)})

3.loss 계산

  • 재구성오차 :  Lrecon(i)=logpθ(x(i)z(i))~~\mathcal{L}_{\text{recon}}^{(i)} = -\log p_\theta(x^{(i)} | z^{(i)})
  • DKL lossD_{KL}~loss : LKL(i)=12j=1J((μj(i))2+(σj(i))2log(σj(i))21)\mathcal{L}_{KL}^{(i)} = \frac{1}{2} \sum_{j=1}^{J} \left( (\mu_j^{(i)})^2 + (\sigma_j^{(i)})^2 - \log(\sigma_j^{(i)})^2 - 1 \right)
  • 전체 손실: L~(i)=Lrecon(i)+LKL(i)\widetilde{\mathcal{L}}^{(i)} = \mathcal{L}_{\text{recon}}^{(i)} + \mathcal{L}_{KL}^{(i)}
  • 미니배치 손실:L~(i)=Lrecon(i)+LKL(i)\widetilde{\mathcal{L}}^{(i)} = \mathcal{L}_{\text{recon}}^{(i)} + \mathcal{L}_{KL}^{(i)}
  • Update: Update:~ θθαθL~M,      ϕϕαϕL~M\theta \leftarrow \theta - \alpha \nabla_\theta \widetilde{\mathcal{L}}^M , ~~~~~~\phi \leftarrow \phi - \alpha \nabla_\phi \widetilde{\mathcal{L}}^M

0개의 댓글