참고1: Annotated Diffusion
참고2: Diffusion Model 과 DDPM 수식 유도 과정
Diffusion이 풀고자 하는 문제는
forward process와 reverse process를 거친 'output 이미지 '를 샘플링하는 확률분포 를
'input 이미지 '를 샘플링하는 확률 분포 와 유사하게 만드는 것이다.
직관적으로 생각해보면,
Forward process를 거치면서
input 이미지 에 노이즈를 time step 만큼 추가해 pure Gaussian noise 를 얻은 후에
다시 Reverse process를 거치면서
를 이용해 Gaussian noise 로부터 를 샘플링하면 될 것이다.
하지만, 우리는 확률 분포 를 알지 못한다.
그래서 우리는 신경망을 이용해서, 확률분포 를 가장 잘 모델링하는, 확률분포 를 예측하고 학습한다.
즉, 우리가 예측하는 확률분포 에서 가 샘플링될 확률이 가장 높게
= Likelihood 가 가장 크도록
= 확률분포 가 에 가까워지도록
신경망을 훈련하면 되겠다!
를 최대화 하는 parameter 는
를 최대화 하는 parameter 와 동일하다. 수식 전개의 편리함을 위해 를 사용한다.
를 최대화 하는 수식을 전개하다보면,
과 간의 KL-Divergence를 최소화하는 식으로 전개된다.
이는 두 분포 간의 평균의 차이를 최소화하는 식으로 전개되며,
위 식을 수학적인 트릭(Reparameterization trick)을 사용하면 (=수식적으로 전개하다보면)
결국 time step 시점에 실제로 추가된 노이즈 과 신경망이 예측한 시점의 노이즈 간의 차이를 최소화하는 식으로 전개된다.
정리하면, 확률분포 가 에 가까워지도록 신경망을 학습시키기 위해선
확률분포 와 확률분포 간의 loss 함수를 정의해야 한다.
Diffusion에서는 확률분포 와 확률분포 가 정규 분포(gaussian distribution)을 따른다고 설정한다.
정규분포는 두 가지 정보, 평균()과 분산()으로 표현된다.
그렇다면, "확률분포 의 평균(), 분산"과
"확률 분포 의 평균(), 분산"의 차이를 줄이는 loss 함수를 만들면 되겠다.
DDPM에서는 분산은 고정된(fixed) 값으로 두고 평균의 차이를 줄이는 방식을 사용했다.
이후에 분산까지 학습하는 연구가 나오긴 했으나, 지금은 DDPM을 기준으로 생각하자.
수학적인 트릭(Reparameterization trick)을 사용하면
(=수식적으로 전개하다보면)
"확률분포 의 평균()"과, "확률 분포 의 평균()"의 차이를 줄이는 loss function 수식을
"이미지에 실제로 추가된 노이즈()"과 "신경망이 예측한 노이즈()"의 차이를 줄이는 loss function 수식으로 전개가 가능하다.
따라서, 신경망은 이미지에 추가된 노이즈를 예측하도록 설계한다.
가장 간단한 버전으로 Diffusion을 구현한 코드를 요약하여, 핵심적인 구현 흐름을 확인해보자.
2-1) variance schedule: scheduling
def linear_beta_schedule(timesteps):
beta_start = 0.0001
beta_end = 0.02
return torch.linspace(beta_start, beta_end, timesteps)
2-2) noising (at time step )
def q_sample(x_start, t, noise=None):
if noise is None:
noise = torch.randn_like(x_start)
sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
sqrt_one_minus_alphas_cumprod_t = extract(
sqert_one_minus_alphas_cumprod, t, x_start.shape
)
return sqrt_alphas_cumpord_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
2-3) Loss Function을 정의한다.
def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):
if noise is None:
noise = torch.randn_like(x_start)
# noised image
x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
# predicted noise. 여기서 denoise_model은 U-Net
predicted_noise = denoise_model(x_noisy, t)
# l1_loss를 사용할 경우
loss = F.l1_loss(noise, predicted_noise)
return loss
1) time step 의 pure noise에서 Gaussian distribution을 sampling한다.
2) denoising
time step T부터 0까지, 신경망을 사용해서, 신경망이 학습한 conditional probability를 이용해 점진적으로 denoise한다.
mean을reparameterization하여 만든 우리의 noise predictor를 이용해서,
우리는 조금 더 denoised 된 image 을 얻을 수 있다.
이 과정을 통해 real data distribution 에서 생성된 것과 유사한 새로운 image를 얻게된다.
@torch.no_grad()
def p_sample(model, x, t, t_index):
betas_t = extract(betas, t, x.shape)
sqrt_one_minus_alphas_cumprod_t = extract(
sqrt_one_minus_alphas_cumprod, t, x.shape
)
# 1/sqrt(\alpha_t)
sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
# Equation 11 in the paper
# Use our model (noise predictor) to predict the mean
model_mean = sqrt_recip_alphas_t * (
x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
)
if t_index == 0:
return model_mean
else:
posterior_variance_t = extract(
posterior_variance, t, x.shape
)
noise = torch.randn_like(x)
# Algorithm2 line 4:
return model_mean + torch.sqrt(posterior_variance_t) * noise
epochs = 6
for epoch in range(epochs):
for step, batch in enumerate(dataloader):
optimizer.zero_grad()
# Algorithm 1 line 3: sample t uniformally for every example in the batch
t = torch.randint(0, timesteps, (batch_size,), device=device).long()
loss = p_losses(denoise_model=model, x_start=batch, t=t, loss_type="huber")
if step % 100 == 0:
print("Loss:", loss.item())
loss.backward()
optimizer.step()
# Algorithm 2 (includint returning all images)
@torch.no_grad()
def p_sample_loop(model, shape):
device = next(model.parameters()).device
b = shape[0]
# start from pure noise (for each example in the batch)
img = torch.randn(shape, device=device)
imgs = []
for i in tqdm(reversed(range(0, timesteps)), desc = 'sampling loop time step', total=timesteps):
img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i)
imgs.append(img.cpu().numpy())
return imgs
@torch.no_grad()
def sample(model, image_size, batch_size=16, channels=3):
return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))
결과 시각화
# show a random one
random_index = 5
plt.imshow(samples[-1][random_index].reshape(image_size, image_size, channels), cmap="gray")