[딥러닝] Reparameterization Trick

Ethan·2023년 3월 10일
0

딥러닝 이론

목록 보기
6/6

본 블로그의 모든 글은 직접 공부하고 남기는 기록입니다.
잘못된 내용/오류 지적이나 추가 의견은 댓글로 꼭 남겨주시면 감사하겠습니다.


Reparameterization Trick

Reparameterization Trick은 주로 sampling 연산을 미분할 수 없어서 backprop을 사용하지 못하는 문제를 해결하기 위해 사용합니다. sampling 과정을 바로 미분할 수 없으니, sampling 연산 과정의 파라미터를 바꿔서 미분이 가능하도록 변환하는 기법입니다.

설명만 보면 간단해 보이지만, Varitional Inference를 비롯하여 그 맥락을 함께 살펴보면 꽤 복잡합니다. 차근차근 살펴보도록 하겠습니다.

Objective

임의의 모델을 fθ(x)f_\theta(x), input data를 xpθ(x)x\sim p_\theta(x)라고 가정해 보겠습니다. 이 때 우리는 모델의 파라미터 θ\theta를 최적화하고 싶습니다. 이 모델을 최적화하는데 필요한 loss function은 다음과 같이 표현할 수 있습니다.

L(θ)=Expθ(x)[fθ(x)](1)L(\theta)=E_{x\sim p_\theta(x)}[f_\theta(x)]\qquad (1)

식 (1)의 기울기를 구해보면 다음과 같습니다.

θL(θ)=θExpθ(x)[fθ(x)]=Expθ(x)[θfθ(x)](2)\begin{aligned} \nabla_\theta L(\theta)&=\nabla_\theta E_{x\sim p_\theta(x)}[f_\theta(x)]\\ \quad\\ &=E_{x\sim p_\theta(x)}[\nabla_\theta f_\theta(x)]\qquad (2) \end{aligned}

식 (2)는 별로 어렵지 않습니다. 만약 NN개의 데이터가 있다면, 아래와 같이 몬테카를로 방법으로 풀 수 있습니다.

θL(θ)1NnNθfθ(xn)\nabla_\theta L(\theta)\simeq{1\over|N|}\sum_n^N\nabla_\theta f_\theta(x_n)

여기까지가 일반적인 딥러닝 모델의 학습 과정입니다. 그런데 만약 다음과 같이 θ\theta 대신 ϕ\phi를 파라미터로 하는 확률분포 qq에서 출력한 zz를 입력으로 받는다면 어떨까요?

zqϕ(z)=pθ(zx)z\sim q_\phi(z)=p_\theta(z|x)

그러면 이 모델은 다음과 같이 쓸 수 있습니다.

fθ(z)=fθ(qϕ(z))f_\theta(z)=f_\theta(q_\phi(z))

이 모델의 cost는 다음과 같습니다.

L(θ;ϕ)=Expθ[fθ(qϕ(z))](3)L(\theta;\phi)=E_{x\sim p_\theta}[f_\theta(q_\phi(z))]\qquad (3)

이제 식 (3)을 가지고 backprop을 해야 하는데, 파라미터가 두 개로 늘었기 때문에 풀어야 하는 식도 다음과 같이 2개로 늘어나게 됩니다.

θL(θ;ϕ), ϕL(θ;ϕ)\nabla_\theta L(\theta;\phi),\ \nabla_\phi L(\theta;\phi)

θL(θ;ϕ)=θExpθ[fθ(pθ(zx))]\nabla_\theta L(\theta;\phi)=\nabla_\theta E_{x\sim p_\theta}[f_\theta(p_\theta(z|x))]이므로, 늘 하던 대로 쉽게 풀 수 있습니다. VAE에서 디코더에 해당하는 부분을 최적화하는 과정입니다. 수식을 보면 zz를 받아서 원본 xx를 복원하도록 되어 있죠.

문제는 인코더에 해당하는 ϕL(θ;ϕ)=ϕExpθ[fθ(qϕ(z))]\nabla_\phi L(\theta;\phi)=\nabla_\phi E_{x\sim p_\theta}[f_\theta(q_\phi(z))]입니다. 이 식을 전개해보면 다음과 같습니다.

ϕL(θ;ϕ)=ϕExpθ[fθ(qϕ(z))]=ϕxfθ(x)qϕ(z)dx=xfθ(x)ϕqϕ(z)dx(4)\begin{aligned} \nabla_\phi L(\theta;\phi)&=\nabla_\phi E_{x\sim p_\theta}[f_\theta(q_\phi(z))]\\ \quad\\ &=\nabla_\phi\int_xf_\theta(x)q_\phi(z)dx\\ \quad\\ &=\int_xf_\theta(x)\nabla_\phi q_\phi(z)dx\qquad(4) \end{aligned}

식 (4)는 서로 다른 파라미터에 대한 식이기 때문에 sampling만으로 해결이 안 됩니다.

Score Function Estimator

그래서 다음과 같은 간단한 트릭을 사용합니다.

ϕlogqϕ(z)=ϕqϕ(z)qϕ(z)\nabla_\phi\log q_\phi(z)={\nabla_\phi q_\phi(z)\over q_\phi(z)}

그러면 아래와 같이 식 (4)를 바꿔 쓸 수 있습니다.

ϕL(θ;ϕ)=xfθ(x)ϕqϕ(z)dx=xfθ(x)qϕ(z)ϕlogqϕ(z)dx=Expθ(x)[fθϕlogqϕ(z)]\begin{aligned} \nabla_\phi L(\theta;\phi)&=\int_xf_\theta(x)\nabla_\phi q_\phi(z)dx\\ \quad\\ &=\int_xf_\theta(x)q_\phi(z)\nabla_\phi\log q_\phi(z)dx\\ \quad\\ &=E_{x\sim p_\theta(x)}[f_\theta\nabla_\phi\log q_\phi(z)] \end{aligned}

마찬가지로 몬테카를로 방법을 적용하여 다음과 같이 정리할 수 있습니다.

ϕL(θ;ϕ)1NnNfθ(xn)ϕlogqϕ(zn)(5)\nabla_\phi L(\theta;\phi)\simeq{1\over |N|}\sum_n^Nf_\theta(x_n)\nabla_\phi \log q_\phi(z_n)\qquad (5)

이로써 주어진 데이터를 가지고 파라미터 θ\thetaϕ\phi에 대해 모두 최적화할 수 있게 되었습니다.

Reparameterization Trick

그런데 이러한 Reparameterization Trick이 잘 작동하려면 2가지 조건이 있습니다.

첫 번째로, qϕ(z)q_\phi(z)ϕ\phi에 대해 미분가능해야 합니다. 당연한 이야기죠? 우리는 pθ(zx)p_\theta(z|x)를 미분할 수 없어서 다른 미분 가능한 함수를 찾고 있는 거니까요.

두 번째로, qϕ(z)q_\phi(z)에 대한 정보가 필요합니다. VAE의 경우, prior를 정규분포로 가정합니다.

Expectation Maximization

위 2가지 조건을 만족한다는 전제 하에, Expectation Maximization (EM) algorithm을 사용해서 각각의 파라미터를 최적화할 수 있습니다. 구체적인 과정은 다음과 같습니다.

(1) DKL(qϕ(z)pθ(zx))D_{KL}(q_\phi(z)||p_\theta(z|x))를 이용하여 ELBO를 구합니다.

(2) 구한 ELBO를 최대화합니다.

Step 1. KLD를 이용한 ELBO 유도

DKL(qϕ(z)pθ(zx))=qϕ(z)logqϕ(z)pθ(zx)dzdx=qϕ(z)logqϕ(z)pθ(x)pθ(xz)pθ(z)dzdx=qϕ(z)logqϕ(z)pθ(z)dzdx+qϕ(z)logpθ(x)dzdxqϕ(z)logpθ(xz)dzdx=DKL(qϕ(z)pθ(z))+logpθ(x)Ezpθ(z)[logpθ(xz)](6)\begin{aligned} D_{KL}(q_\phi (z)||p_\theta(z|x))&=\int q_\phi(z)\log{q_\phi(z)\over p_\theta(z|x)}dzdx\\ \quad\\ &=\int q_\phi(z)\log{q_\phi(z)p_\theta(x)\over p_\theta(x|z)p_\theta(z)}dzdx\\ \quad\\ &=\int q_\phi(z)\log{q_\phi(z)\over p_\theta(z)}dzdx+\int q_\phi(z)\log p_\theta(x)dzdx-\int q_\phi(z)\log p_\theta(x|z)dzdx\\ \quad\\ &=D_{KL}(q_\phi(z)||p_\theta(z))+\log p_\theta(x)-E_{z\sim p_\theta(z)}[\log p_\theta(x|z)]\qquad (6) \end{aligned}

식 (6)을 logpθ(x)\log p_\theta(x)에 대해 정리하면 다음과 같습니다.

logpθ(x)=Ezpθ(z)[logpθ(xz)]DKL(qϕ(z)pθ(z))+DKL(qϕ(z)pθ(zx))\log p_\theta(x) = E_{z\sim p_\theta(z)}[\log p_\theta(x|z)]-D_{KL}(q_\phi(z)||p_\theta(z))+D_{KL}(q_\phi(z)||p_\theta(z|x))

KL Divergence 값은 항상 양수이므로,

DKL(qϕ(z)pθ(zx))0D_{KL}(q_\phi(z)||p_\theta(z|x)) \geq 0

따라서 다음과 같이 ELBO 형태로 정리됩니다.

logpθ(x)Ezpθ(z)[logpθ(xz)]DKL(qϕ(z)pθ(z))(7)\log p_\theta(x) \geq E_{z\sim p_\theta(z)}[\log p_\theta(x|z)]-D_{KL}(q_\phi(z)||p_\theta(z))\qquad (7)

Step 2. ELBO 최대화

식 (7)의 우변을 최대화하려면 DKL(qϕ(z)pθ(z))D_{KL}(q_\phi(z)||p_\theta(z))가 최소화되어야 합니다. 그런데 우리는 앞서 qϕ(z)q_\phi(z)를 정규분포로 가정했습니다. 따라서 pθ(z)p_\theta(z) 또한 정규분포여야 합니다. 즉, 두 정규분포의 차이를 최소화하는 문제가 됩니다.

pθ(zx)=qϕ(z)pθ(z)N(μ, σ2)(8)p_\theta(z|x)=q_\phi(z)\simeq p_\theta(z)\sim N(\mu,\ \sigma^2)\qquad(8)

식 (8)이 의미하는 바는 xx에서 샘플링한 zz를 정규분포의 평균과 분산의 합으로 표현할 수 있다는 것입니다. 다시 말해, xx에서 zz를 샘플링하는 것과 모델이 근사한 정규분포 NN을 따르는 랜덤한 샘플을 뽑는 것이 같다는 뜻이 됩니다!

따라서 데이터 샘플 nn에 대해 다음이 성립합니다.

zn=μ+σn×ϵn,ϵN(0,1)z_n=\mu+\sigma_n\times\epsilon_n,\quad\epsilon\sim N(0,1)

참고문헌

  1. Gumbel Distribution (Wiki)
  2. Jaejun Yoo's Playground - Categorical Reparameterization with Gumbel Softmax
  3. Jaejun Yoo님의 PR12 발표 QnA
  4. AAA (All About AI) - Gumbel-Softmax Trick
  5. Kaen's Ritus - Gumbel-Softmax 리뷰
  6. Leonard Tang - The Gumbel-Max Trick: Explained
  7. Towards open set deep networks
  8. Hulk - Reparametrization Trick 정리
  9. 데이터 사이언스 스쿨 - 기댓값과 확률변수의 변환
  10. 두 표본평균 차이의 표본분포
  11. KL-divergence with Gaussian distribution 증명
  12. Gumbel Softmax 설명
  13. Yarin Gal - bayesian deep learning
profile
재미있게 살고 싶은 대학원생

2개의 댓글

comment-user-thumbnail
2023년 4월 17일

안녕하세요. 좋은 포스팅 감사합니다.

Reparameterization Trick부분에서 4번 수식의 첫 번째 줄 => 두 번째 줄로 가는 과정과,

Score Function Estimator의 중간에 있는 수식 3줄 부분에서, 두 번째 줄 => 세 번째 줄로 가는 과정에 대해서 구체적으로 설명해주실 수 있나요?

1개의 답글