DDPM 구현

pyross·2024년 11월 22일
0

이전 논문 리뷰 후 구현을 해보았다.

전체 코드는 깃허브 링크를 보면 된다.

  • 실제로 MNIST를 학습하고 생성한 이미지

구현


우선 전체 학습 구현은 위의 알고리즘이 전부이다.

model은 생략하겠다. model에 대한 논문이 아니기 때문
model은 U-Net에 time embedding을 활용한 것을 가져와서 base로 사용하였다.

우선 학습을 위해서 필요한 beta, alpha를 미리 정의해둔다.

T = 1000
beta_start = 1e-4
beta_end = 0.02

betas = torch.linspace(beta_start, beta_end, T).to(device)
alphas = 1 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0).to(device)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).to(device)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - alphas_cumprod).to(device)

training


위 의사코드를 보고 이해하면 매우 쉽다.

for epoch in tqdm(range(epochs)):
    total_loss = 0
    for i, (x, _) in enumerate(train_loader):
        x = x.to(device)
        batch_size = x.size(0)

        t = torch.randint(0, T, (batch_size,)).to(device)

        noise = torch.randn_like(x).to(device)

        # 이미지에 노이즈 추가
        sqrt_alphas_cumprod_t = sqrt_alphas_cumprod[t].reshape(x.shape[0], 1, 1, 1)
        sqrt_one_minus_alphas_cumprod_t = sqrt_one_minus_alphas_cumprod[t].reshape(
            x.shape[0], 1, 1, 1)
        # reparametrization trick
        noise_x = sqrt_alphas_cumprod_t * x + sqrt_one_minus_alphas_cumprod_t * noise

        noise_pred = model(noise_x, t)

        noise_loss = criterion(noise_pred, noise)
        optimizer.zero_grad()
        noise_loss.backward()
        optimizer.step()
        total_loss += noise_loss.item()

    print(f"Epoch [{epoch}], Loss: {total_loss/len(train_loader)}")

설명하자면 임의의 time t를 뽑고 xt=αˉtx0+1αtˉϵ,ϵN(0,I)x_t=\sqrt{\bar \alpha_t}x_0 +\sqrt{1-\bar{\alpha_t}}\epsilon, \quad \epsilon \sim \mathcal{N}(0,I)로 빠르게 만든다.
이를 model에 넣어서 noise를 예측하고 실제 noise와 mseloss를 이용해서 학습한다.

사실 이게 끝이다. 매우 간단하다.

sampling

역시 위 의사코드가 전부이다.

@torch.no_grad()
def sample(model, img_size, alphas, alphas_cumprod, betas, T, device, batch_size=64):
    model.eval()
    x = torch.randn(batch_size, 1, img_size, img_size).to(device)  # 순수 노이즈로부터 시작
    noise_to_x = [x]
    for t in reversed(range(T)):
        # 현재 타임스텝 t
        t_tensor = torch.full((batch_size,), t, device=device, dtype=torch.long)

        # 노이즈 예측합니다.
        noise_pred = model(x, t_tensor)
        beta_t = betas[t].to(device)

        # 이전 x를 계산합니다.
        if t > 0:
            noise = torch.randn_like(x)
        else:
            noise = torch.zeros_like(x)

        x = (1 / torch.sqrt(alphas[t])) * (x - ((1 - alphas[t]) / torch.sqrt(1 - alphas_cumprod[t])) * noise_pred) + torch.sqrt(beta_t) * noise
    return x

위처럼 구현이 되는데 처음 random noise

x = torch.randn(batch_size, 1, img_size, img_size).to(device)

에서 시작해서 T를 reverse로 복구한다.
model에 현재의 noise된 image를 넣어서 noise를 예측하고
기존에 유도된 수식에 따라서

x = (1 / torch.sqrt(alphas[t])) * (x - ((1 - alphas[t]) / torch.sqrt(1 - alphas_cumprod[t])) * noise_pred) + torch.sqrt(beta_t) * noise

다음 x를 이렇게 만들어낸다.
사실 sampling도 이게 전부이다.

의외로 매우 간단했다.

0개의 댓글