VAE: Auto-Encoding Variational Bayes

김태규·2025년 1월 18일

생성모델

목록 보기
2/3

Introduction

오늘 리뷰할 논문은 VAE(variation auto-encoding)이다.
Auto-Encoder가 기존의 데이터를 효과적으로 압축시키고 복원시키는 것을 목표로 하는 것에 반해 VAE는 원본의 데이터를 사용해서 원본과 약간은 다른 데이터를 생성하는 것을 목표로 한다.

이 포스팅은 VAE의 전체적인 구조와 작동방식을 직관적으로 이해하는 것을 목표로 한다.


구조설명

아래는 VAE의 구조이다.

VAE의 목표는 Generative model의 목표와 같다.
input xix_i의 분포와 비슷한 latent variable ziz_i를 통해 xix_i와 다른 새로운 pip_i를 생성하는 것이 목표이다.

내가 환타를 모델어 집어넣으면 탄산음료지만 환타와는 다른 사이다가 나오길 바라는 것이다.

좀 더 자세하게 살펴보자

구조는 크게 Gaussian Encoder, Reparameterization Trick(Sampling), Decoder 로 나눌 수 있다.

  • Gaussian Encoderxix_i의 정보를 압축하는 부분이라고 생각하면 된다.

  • Reparameterization Trick(Sampling)Gaussian Encoder를 거쳐온 정보들에 Noise를 추가하는 단계라고 보면 된다. ( 미분이 가능하도록 수식을 변경하는 부분이기도 함)

  • Decoder는 샘플링 된 ziz_i를 디코딩 하는 하여 원본과 비슷한 이미지를 만드는 단계이다.


Loss

그렇다면 우리가 이 시스템을 만들기 위해 해야 될 작업은 무엇일까?

MLE(Maximum Likelihood Estimation)
바로 logPθ(xi)\log P_{\theta}(x_i)를 최대화 하는 것이다.

Pθ(xi)P_{\theta}(x_i): xix_i를 넣었을때 나오는 모델의 output

xx를 넣었을때 xx가 나올 확률을 Maximize해주는 분포를 만들어주는 모델의 파라미터 θ\theta를 찾아야 한다. (분포를 찾는 것과 동일)

모든 식은 이 logPθ(xi)\log P_{\theta}(x_i)에서 나온다.

수식의 전개는 다음과 같다.

logpθ(x(i))=Ezqϕ(zx(i))[logpθ(x(i))]log p_\theta(x^{(i)}) = \mathbb{E}_{z \sim q_\phi(z|x^{(i)})} \left[ \log p_\theta(x^{(i)}) \right]

=Ez[logpθ(x(i)z)pθ(z)pθ(zx(i))](Bayes’ Rule)= \mathbb{E}_z \left[ \log \frac{p_\theta(x^{(i)}|z)p_\theta(z)}{p_\theta(z|x^{(i)})} \right] \quad \text{(Bayes' Rule)}

=Ez[logpθ(x(i)z)pθ(z)qϕ(zx(i))qϕ(zx(i))pθ(zx(i))]= \mathbb{E}_z \left[ \log \frac{p_\theta(x^{(i)}|z)p_\theta(z)}{q_\phi(z|x^{(i)})} \cdot \frac{q_\phi(z|x^{(i)})}{p_\theta(z|x^{(i)})} \right]

=Ez[logpθ(x(i)z)]Ez[logqϕ(zx(i))pθ(z)]+Ez[logqϕ(zx(i))pθ(zx(i))]= \mathbb{E}_z \left[ \log p_\theta(x^{(i)}|z) \right] - \mathbb{E}_z \left[ \log \frac{q_\phi(z|x^{(i)})}{p_\theta(z)} \right] + \mathbb{E}_z \left[ \log \frac{q_\phi(z|x^{(i)})}{p_\theta(z|x^{(i)})} \right]

=Ez[logpθ(x(i)z)]DKL(qϕ(zx(i))pθ(z))+DKL(qϕ(zx(i))pθ(zx(i)))= \mathbb{E}_z \left[ \log p_\theta(x^{(i)}|z) \right] - D_\text{KL} \left( q_\phi(z|x^{(i)}) \parallel p_\theta(z) \right) + D_\text{KL} \left( q_\phi(z|x^{(i)}) \parallel p_\theta(z|x^{(i)}) \right)

수식의 첫 번째 줄은 z에 대한 수식이 없기 때문에 단순히 상수에 z에 대한 평균을 취한 것과 동일하다고 생각하면 된다.

EX) E[C]=CE[C] = C 와 같음

두 번째 줄의 수식은 Bayes rule이 사용되었다.

세 번째 줄의 수식은 분자와 분모에 모두 qϕ(zx)q_{\phi}(z|x)를 곱해주었다.

네 번째 줄의 수식은 수식을 분리하였고

마지막 줄은 분리된 수식에 대해서 KL-Divergence로 변경을 해주었다.

KL-Divergence 변환 과정

마지막 수식을 살펴보자

Ez[logpθ(x(i)z)]DKL(qϕ(zx(i))pθ(z))+DKL(qϕ(zx(i))pθ(zx(i)))\mathbb{E}_z \left[ \log p_\theta(x^{(i)}|z) \right] - D_\text{KL} \left( q_\phi(z|x^{(i)}) \parallel p_\theta(z) \right) + D_\text{KL} \left( q_\phi(z|x^{(i)}) \parallel p_\theta(z|x^{(i)}) \right)

  • 첫 번째 항: Encoder로 부터 zz를 sampling하고 sampling된 zz로 부터 output을 뽑아낼 수 있다. (계산 가능함 + zz에서 xx가 나올 확률을 극대화함.)

  • 두 번째 항: Encoder를 통과한 qϕ(zx(i))q_\phi(z|x^{(i)})가 우리가 정의한 pθ(z)p_\theta(z)의 분포와 비슷해지도록 만든다. (계산이 가능함)

  • 세 번째 항: qϕ(zx(i))q_\phi(z|x^{(i)})는 계산이 가능하지만 pθ(zx(i))p_\theta(z|x^{(i)})xizx_i \rightarrow z의 정답분포는 구할 수 없다. (계산이 불가능함. 하지만 KL-Divergence가 항상 양수임은 알 수 있음.)

우리는 우리가 계산할 수 있는 부분만 살펴볼 수 있다.

세 번째 항은 알 수 없기 때문에 논문에서는 첫 번째와 두 번째항을 최대화시킨다.

계산할 수 있는 부분을 ELBO라고 한다.
ELBO는 항상 다음의 수식을 만족한다. (세 번째 항은 항상 양수이기 때문)

Ez[logpθ(x(i)z)]DKL(qϕ(zx(i))pθ(z))+DKL(qϕ(zx(i))pθ(zx(i)))\mathbb{E}_z \left[ \log p_\theta(x^{(i)}|z) \right] - D_\text{KL} \left( q_\phi(z|x^{(i)}) \parallel p_\theta(z) \right) + D_\text{KL} \left( q_\phi(z|x^{(i)}) \parallel p_\theta(z|x^{(i)}) \right)

>Ez[logpθ(x(i)z)]DKL(qϕ(zx(i))pθ(z))>\mathbb{E}_z \left[ \log p_\theta(x^{(i)}|z) \right] - D_\text{KL} \left( q_\phi(z|x^{(i)}) \parallel p_\theta(z) \right)

Reconstruction Error

Ez[logpθ(x(i)z)]\mathbb{E}_z \left[ \log p_\theta(x^{(i)}|z) \right]

우리는 이제 이 부분을 Reconstruction Error라고 부른다.
Reconstruction Error는 다음과 같이 표현이 가능하다.

Ez[logpθ(x(i)z)]=logpθ(xz)qϕ(zx)dz\mathbb{E}_z \left[ \log p_\theta(x^{(i)}|z) \right] = \int \log p_{\theta}(x|z)q_{\phi}(z|x)dz의 계산으로 나타낼 수 있다.

하지만 모든 구간의 zz에 대해 적분할 수 없기 때문에 논문에서는 Monte Carlo Method를 사용한다.

Ez[logpθ(x(i)z)]1Lzi,llogpθ(x(i)zi,l)\mathbb{E}_z \left[ \log p_\theta(x^{(i)}|z) \right] \approx \frac{1}{L} \sum_{z^{i, l}}\log p_{\theta}(x^{(i)}|z^{i, l})

이를 다음과 같이 표현가능하다.

근사가 잘 되기 위해서는 매우 큰 LL값을 선택해야 하는데 우리는 한번 sample을 뽑을 때마다 계산을 해야한다. 연산량이 커지는 것이다.

논문에서는 이를 막기위해 L=1L=1로 고정시켜 Random sample을 1개 뽑아 사용한다.

따라서 Reconstruction Error는 다음과 같게 된다.

logpθ(x(i)zi,l)\log p_{\theta}(x^{(i)}|z^{i, l})

이 수식을 계산하기 위해서 우리는 원본의 분포를 가정해야 한다.

논문에서는 두 가지 분포를 가정한다.

  1. 베르누이 분포

log(pθ(xizi))=logj=1Dpθ(xi,jzi)=j=1Dlogpθ(xi,jzi)\log \left( p_\theta(x^i | z^i) \right) = \log \prod_{j=1}^D p_\theta(x_{i,j} | z^i) = \sum_{j=1}^D \log p_\theta(x_{i,j} | z^i)

=j=1Dlog(pi,jxi,j(1pi,j)1xi,j)= \sum_{j=1}^D \log \left( p_{i,j}^{x_{i,j}} (1 - p_{i,j})^{1 - x_{i,j}} \right)

=j=1D(xi,jlogpi,j+(1xi,j)log(1pi,j))= \sum_{j=1}^D \left( x_{i,j} \log p_{i,j} + (1 - x_{i,j}) \log (1 - p_{i,j}) \right)

베르누이 분포를 따른다고 가정했을 때 다음과 같이 수식을 전개할 수 있다.
최종적으로 베르누이 분포를 따르는 Reconstruction ErrorCross Entropy Loss가 나온다.

  1. 가우시안 분포

log(pθ(xizi))=log(N(xi;μi,σi2I))\log \left( p_\theta(x^i | z^i) \right) = \log \left( \mathcal{N}(x_i; \mu_i, \sigma_i^2 I) \right)

=j=1D(12log(σi,j2)+(xi,jμi,j)22σi,j2)= -\sum_{j=1}^D \left( \frac{1}{2} \log (\sigma_{i,j}^2) + \frac{(x_{i,j} - \mu_{i,j})^2}{2 \sigma_{i,j}^2} \right)

For Gaussian distribution with identity covariance\text{For Gaussian distribution with identity covariance}

log(pθ(xizi))j=1D(xi,jμi,j)2Squared Error\log \left( p_\theta(x^i | z^i) \right) \propto -\sum_{j=1}^D (x_{i,j} - \mu_{i,j})^2 \quad \text{Squared Error}

가우시안 분포를 따른다고 가정했을 때 다음과 같이 수식을 전개할 수 있다.
최종적으로 가우시안 분포를 따르는 Reconstruction ErrorSquared Error가 나온다.

Reconstruction Error가 베르누이 분포를 따를 때 Decoder를 Bernoulli Decoder라고 하고 가우시안 분포를 따를 때 Decoder를 Gaussian Decoder라고 한다.

Regularization

우리는 수식의 두번째 항을 Regularization이라고 부른다.

DKL(qϕ(zx(i))pθ(z))D_\text{KL} \left( q_\phi(z|x^{(i)}) \parallel p_\theta(z) \right)

우리가 정의한 분포인 pθ(z)p_{\theta}(z)분포와 qϕ(zx(i))q_{\phi}(z|x^{(i)})와 비슷하게 만들기 때문이다.

Regularization이 없다고 가정해보자.

그렇다면 zz를 사용하여 원본 xx가 나올 확률만을 극대화 하도록 만들 것이다.

따라서 qϕ(zx(i))q_{\phi}(z|x^{(i)})분포를 어느정도 조정함으로써, VAE가 Reconstruction task만을 잘하는 것을 방지한다. 그렇기 때문에 Regularization 불리는 것이다

계산방식은 다음과 같다.

Reparameterization Trick

이 부분은 latent space zz에 대해서 Sampling을 하는 부분이다. 물론 그냥 Sampling을 하면 안된다. (미분이 불가능함)

단순히 확률 분포에서 Sample을 뽑는 것은 backpropagation이 불가능 하기 때문에 Normal distribution을 따르는 ϵ\epsilon을 사용하여 샘플링 한다. (오른쪽 방식)

Decoder의 output은 μ\mu(평균)와 σ\sigma(표준편차)가 나오는데 우리는 noise를 섞어줌과 동시에 미분이 가능하도록 오른쪽과 같은 수식으로 zz를 Sampling한다.

sampling을 하는 ϵ\epsilon은 backpropagation하지 않기 때문에 Decoder방향으로 미분이 가능함


코드 구현

import torch
from torch import nn, optim
import torch.utils.data
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=128, shuffle=True, num_workers=4, pin_memory=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=False, download=True,
                   transform=transforms.ToTensor()),
    batch_size=128, shuffle=False, num_workers=4, pin_memory=True)

class VAE(nn.Module):
    def __init__(self, input_dim=784, h_dim=400, latent_dim=20):
        super(VAE, self).__init__()

        self.fc1 = nn.Linear(input_dim, h_dim)
        self.fc2_mu = nn.Linear(h_dim, latent_dim)
        self.fc2_var = nn.Linear(h_dim, latent_dim)
        self.fc3 = nn.Linear(latent_dim, h_dim)
        self.fc4 = nn.Linear(h_dim, input_dim)

    def encode(self, x):
        out = F.relu(self.fc1(x))
        return self.fc2_mu(out), self.fc2_var(out)

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, x):
        out = F.relu(self.fc3(x))
        return torch.sigmoid(self.fc4(out))

    def forward(self, x):
        mu, log_var = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

def loss_function(recon_x, x, mu, log_var):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

    return BCE + KLD

def train(epoch):
    model.train()
    train_loss = 0
    for idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon, mu, log_var = model(data)
        loss = loss_function(recon, data, mu, log_var)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader.dataset)))

def test(epoch):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for idx, (data, _) in enumerate(test_loader):
            data = data.to(device)
            recon, mu, log_var = model(data)
            test_loss += loss_function(recon, data, mu, log_var).item()
            if idx == 0:
                n = min(data.size(0), 10)
                comparison = torch.cat([data[:n],
                                        recon.view(-1, 1, 28, 28)[:n]])
                save_image(comparison.cpu(),
                           './results/epoch_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

if __name__ == "__main__":
    s = torch.randn(64, 20).to(device)
    for epoch in range(0, 101):
        train(epoch)
        test(epoch)
        if epoch % 10 == 0:
            sample = model.decode(s).cpu()
            save_image(sample.view(64, 1, 28, 28),
                       'results/sample_{}_result.png'.format(epoch))

실행 결과

profile
발전하는 개발자입니다!

0개의 댓글