본 블로그의 모든 글은 직접 공부하고 남기는 기록입니다.
잘못된 내용/오류 지적이나 추가 의견은 댓글로 꼭 남겨주시면 감사하겠습니다.
Reparameterization Trick
Reparameterization Trick은 주로 sampling 연산을 미분할 수 없어서 backprop을 사용하지 못하는 문제를 해결하기 위해 사용합니다. sampling 과정을 바로 미분할 수 없으니, sampling 연산 과정의 파라미터를 바꿔서 미분이 가능하도록 변환하는 기법입니다.
설명만 보면 간단해 보이지만, Varitional Inference를 비롯하여 그 맥락을 함께 살펴보면 꽤 복잡합니다. 차근차근 살펴보도록 하겠습니다.
Objective
임의의 모델을 fθ(x), input data를 x∼pθ(x)라고 가정해 보겠습니다. 이 때 우리는 모델의 파라미터 θ를 최적화하고 싶습니다. 이 모델을 최적화하는데 필요한 loss function은 다음과 같이 표현할 수 있습니다.
L(θ)=Ex∼pθ(x)[fθ(x)](1)
식 (1)의 기울기를 구해보면 다음과 같습니다.
∇θL(θ)=∇θEx∼pθ(x)[fθ(x)]=Ex∼pθ(x)[∇θfθ(x)](2)
식 (2)는 별로 어렵지 않습니다. 만약 N개의 데이터가 있다면, 아래와 같이 몬테카를로 방법으로 풀 수 있습니다.
∇θL(θ)≃∣N∣1n∑N∇θfθ(xn)
여기까지가 일반적인 딥러닝 모델의 학습 과정입니다. 그런데 만약 다음과 같이 θ 대신 ϕ를 파라미터로 하는 확률분포 q에서 출력한 z를 입력으로 받는다면 어떨까요?
z∼qϕ(z)=pθ(z∣x)
그러면 이 모델은 다음과 같이 쓸 수 있습니다.
fθ(z)=fθ(qϕ(z))
이 모델의 cost는 다음과 같습니다.
L(θ;ϕ)=Ex∼pθ[fθ(qϕ(z))](3)
이제 식 (3)을 가지고 backprop을 해야 하는데, 파라미터가 두 개로 늘었기 때문에 풀어야 하는 식도 다음과 같이 2개로 늘어나게 됩니다.
∇θL(θ;ϕ), ∇ϕL(θ;ϕ)
∇θL(θ;ϕ)=∇θEx∼pθ[fθ(pθ(z∣x))]이므로, 늘 하던 대로 쉽게 풀 수 있습니다. VAE에서 디코더에 해당하는 부분을 최적화하는 과정입니다. 수식을 보면 z를 받아서 원본 x를 복원하도록 되어 있죠.
문제는 인코더에 해당하는 ∇ϕL(θ;ϕ)=∇ϕEx∼pθ[fθ(qϕ(z))]입니다. 이 식을 전개해보면 다음과 같습니다.
∇ϕL(θ;ϕ)=∇ϕEx∼pθ[fθ(qϕ(z))]=∇ϕ∫xfθ(x)qϕ(z)dx=∫xfθ(x)∇ϕqϕ(z)dx(4)
식 (4)는 서로 다른 파라미터에 대한 식이기 때문에 sampling만으로 해결이 안 됩니다.
Score Function Estimator
그래서 다음과 같은 간단한 트릭을 사용합니다.
∇ϕlogqϕ(z)=qϕ(z)∇ϕqϕ(z)
그러면 아래와 같이 식 (4)를 바꿔 쓸 수 있습니다.
∇ϕL(θ;ϕ)=∫xfθ(x)∇ϕqϕ(z)dx=∫xfθ(x)qϕ(z)∇ϕlogqϕ(z)dx=Ex∼pθ(x)[fθ∇ϕlogqϕ(z)]
마찬가지로 몬테카를로 방법을 적용하여 다음과 같이 정리할 수 있습니다.
∇ϕL(θ;ϕ)≃∣N∣1n∑Nfθ(xn)∇ϕlogqϕ(zn)(5)
이로써 주어진 데이터를 가지고 파라미터 θ와 ϕ에 대해 모두 최적화할 수 있게 되었습니다.
Reparameterization Trick
그런데 이러한 Reparameterization Trick이 잘 작동하려면 2가지 조건이 있습니다.
첫 번째로, qϕ(z)가 ϕ에 대해 미분가능해야 합니다. 당연한 이야기죠? 우리는 pθ(z∣x)를 미분할 수 없어서 다른 미분 가능한 함수를 찾고 있는 거니까요.
두 번째로, qϕ(z)에 대한 정보가 필요합니다. VAE의 경우, prior를 정규분포로 가정합니다.
Expectation Maximization
위 2가지 조건을 만족한다는 전제 하에, Expectation Maximization (EM) algorithm을 사용해서 각각의 파라미터를 최적화할 수 있습니다. 구체적인 과정은 다음과 같습니다.
(1) DKL(qϕ(z)∣∣pθ(z∣x))를 이용하여 ELBO를 구합니다.
(2) 구한 ELBO를 최대화합니다.
Step 1. KLD를 이용한 ELBO 유도
DKL(qϕ(z)∣∣pθ(z∣x))=∫qϕ(z)logpθ(z∣x)qϕ(z)dzdx=∫qϕ(z)logpθ(x∣z)pθ(z)qϕ(z)pθ(x)dzdx=∫qϕ(z)logpθ(z)qϕ(z)dzdx+∫qϕ(z)logpθ(x)dzdx−∫qϕ(z)logpθ(x∣z)dzdx=DKL(qϕ(z)∣∣pθ(z))+logpθ(x)−Ez∼pθ(z)[logpθ(x∣z)](6)
식 (6)을 logpθ(x)에 대해 정리하면 다음과 같습니다.
logpθ(x)=Ez∼pθ(z)[logpθ(x∣z)]−DKL(qϕ(z)∣∣pθ(z))+DKL(qϕ(z)∣∣pθ(z∣x))
KL Divergence 값은 항상 양수이므로,
DKL(qϕ(z)∣∣pθ(z∣x))≥0
따라서 다음과 같이 ELBO 형태로 정리됩니다.
logpθ(x)≥Ez∼pθ(z)[logpθ(x∣z)]−DKL(qϕ(z)∣∣pθ(z))(7)
Step 2. ELBO 최대화
식 (7)의 우변을 최대화하려면 DKL(qϕ(z)∣∣pθ(z))가 최소화되어야 합니다. 그런데 우리는 앞서 qϕ(z)를 정규분포로 가정했습니다. 따라서 pθ(z) 또한 정규분포여야 합니다. 즉, 두 정규분포의 차이를 최소화하는 문제가 됩니다.
pθ(z∣x)=qϕ(z)≃pθ(z)∼N(μ, σ2)(8)
식 (8)이 의미하는 바는 x에서 샘플링한 z를 정규분포의 평균과 분산의 합으로 표현할 수 있다는 것입니다. 다시 말해, x에서 z를 샘플링하는 것과 모델이 근사한 정규분포 N을 따르는 랜덤한 샘플을 뽑는 것이 같다는 뜻이 됩니다!
따라서 데이터 샘플 n에 대해 다음이 성립합니다.
zn=μ+σn×ϵn,ϵ∼N(0,1)
참고문헌
- Gumbel Distribution (Wiki)
- Jaejun Yoo's Playground - Categorical Reparameterization with Gumbel Softmax
- Jaejun Yoo님의 PR12 발표 QnA
- AAA (All About AI) - Gumbel-Softmax Trick
- Kaen's Ritus - Gumbel-Softmax 리뷰
- Leonard Tang - The Gumbel-Max Trick: Explained
- Towards open set deep networks
- Hulk - Reparametrization Trick 정리
- 데이터 사이언스 스쿨 - 기댓값과 확률변수의 변환
- 두 표본평균 차이의 표본분포
- KL-divergence with Gaussian distribution 증명
- Gumbel Softmax 설명
- Yarin Gal - bayesian deep learning
안녕하세요. 좋은 포스팅 감사합니다.
Reparameterization Trick부분에서 4번 수식의 첫 번째 줄 => 두 번째 줄로 가는 과정과,
Score Function Estimator의 중간에 있는 수식 3줄 부분에서, 두 번째 줄 => 세 번째 줄로 가는 과정에 대해서 구체적으로 설명해주실 수 있나요?