[Annotated Diffusion] DDPM-(4) Summarizing the Code's Overall Flow

YEOM JINSEOP · December 17, 2023

Generative Model


Source: Annotated Diffusion

Implementation summary (simplest version)

1. Use a U-Net as the neural network.

2. Define the forward process.

2-1) Variance schedule: scheduling $\beta_t$

```python
import torch


def linear_beta_schedule(timesteps):
    # betas increase linearly from 1e-4 to 0.02 over `timesteps` steps
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start, beta_end, timesteps)
```
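The later steps use several constants derived from $\beta_t$ (`betas`, `sqrt_alphas_cumprod`, `sqrt_one_minus_alphas_cumprod`, `sqrt_recip_alphas`, `posterior_variance`) that this summary does not define. Below is a minimal sketch of how they can be precomputed, assuming $T = 1000$ steps (the value used in the DDPM paper) and the posterior variance $\tilde{\beta}_t = \beta_t(1-\bar{\alpha}_{t-1})/(1-\bar{\alpha}_t)$. It uses NumPy instead of PyTorch only so it runs on its own, and `make_schedule` is a hypothetical helper name.

```python
import numpy as np


def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    # same linear schedule as above, in NumPy
    return np.linspace(beta_start, beta_end, timesteps)


def make_schedule(timesteps=1000):
    """Hypothetical helper: precompute the per-timestep constants used by DDPM."""
    betas = linear_beta_schedule(timesteps)
    alphas = 1.0 - betas
    alphas_cumprod = np.cumprod(alphas)  # alpha_bar_t
    # alpha_bar_{t-1}, with the convention alpha_bar before the first step = 1
    alphas_cumprod_prev = np.concatenate([[1.0], alphas_cumprod[:-1]])
    return {
        "betas": betas,
        "alphas_cumprod": alphas_cumprod,
        "sqrt_recip_alphas": np.sqrt(1.0 / alphas),
        "sqrt_alphas_cumprod": np.sqrt(alphas_cumprod),
        "sqrt_one_minus_alphas_cumprod": np.sqrt(1.0 - alphas_cumprod),
        # variance of the forward-process posterior q(x_{t-1} | x_t, x_0)
        "posterior_variance": betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod),
    }


sched = make_schedule(1000)
print(sched["alphas_cumprod"][-1])  # about 4e-5
```

With this schedule $\bar{\alpha}_T$ ends up around $4 \times 10^{-5}$, so $\mathbf{x}_T$ is almost indistinguishable from $\mathcal{N}(\mathbf{0}, \mathbf{I})$, which is what justifies starting the reverse process from pure noise. A shorter schedule with the same endpoints leaves more of the signal in $\mathbf{x}_T$.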

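2-2) Noising (at time step $t$)
Using the closed form $\mathbf{x}_t = \sqrt{\bar{\alpha}_t}\,\mathbf{x}_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon$ with $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ and $\bar{\alpha}_t = \prod_{s=1}^{t}(1-\beta_s)$, a noisy sample at any $t$ can be drawn in a single step, and the training objective becomes
$$\|\epsilon-\epsilon_\theta(\mathbf{x}_t, t)\|^2=\|\epsilon-\epsilon_\theta(\sqrt{\bar{\alpha}_t}\,\mathbf{x}_0+\sqrt{1-\bar{\alpha}_t}\,\epsilon,\ t)\|^2$$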

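```python
import torch

# `extract`, `sqrt_alphas_cumprod` and `sqrt_one_minus_alphas_cumprod` are assumed
# to be defined elsewhere (precomputed from the beta schedule).
def q_sample(x_start, t, noise=None):
    """Sample x_t ~ q(x_t | x_0) in closed form."""
    if noise is None:
        noise = torch.randn_like(x_start)

    sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x_start.shape
    )

    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
```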
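A NumPy sketch of the same closed-form noising, with a hypothetical `extract` that picks one coefficient per batch element and reshapes it to `(B, 1, 1, 1)` so it broadcasts over `(B, C, H, W)` images. The schedule is recomputed so the snippet stands alone, and the check is empirical: $\mathbf{x}_t$ should have mean $\sqrt{\bar{\alpha}_t}\,\mathbf{x}_0$ and standard deviation $\sqrt{1-\bar{\alpha}_t}$.

```python
import numpy as np

T = 1000
betas = np.linspace(0.0001, 0.02, T)
alphas_cumprod = np.cumprod(1.0 - betas)
sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - alphas_cumprod)


def extract(a, t, x_shape):
    """Pick a[t[i]] for each batch element, reshaped to (B, 1, 1, ...) for broadcasting."""
    return a[t].reshape(len(t), *((1,) * (len(x_shape) - 1)))


def q_sample(x_start, t, noise=None, rng=None):
    """x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise."""
    if noise is None:
        if rng is None:
            rng = np.random.default_rng()
        noise = rng.standard_normal(x_start.shape)
    return (extract(sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)


rng = np.random.default_rng(0)
x0 = np.ones((2, 1, 128, 128))  # two constant "images" with pixel value 1
t = np.array([10, 500])         # a different timestep for each batch element
xt = q_sample(x0, t, rng=rng)
print(xt.mean(axis=(1, 2, 3)), sqrt_alphas_cumprod[t])           # empirical vs. expected mean
print(xt.std(axis=(1, 2, 3)), sqrt_one_minus_alphas_cumprod[t])  # empirical vs. expected std
```

Because every timestep has its own precomputed coefficient, a batch can mix timesteps freely, which is exactly what the training loop relies on when it samples one `t` per example.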

2-3) Define the loss function.

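```python
import torch
import torch.nn.functional as F


def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):
    if noise is None:
        noise = torch.randn_like(x_start)

    # noised image
    x_noisy = q_sample(x_start=x_start, t=t, noise=noise)

    # predicted noise; here denoise_model is the U-Net
    predicted_noise = denoise_model(x_noisy, t)

    # compare the true noise with the predicted noise
    if loss_type == "l1":
        loss = F.l1_loss(noise, predicted_noise)
    elif loss_type == "l2":
        loss = F.mse_loss(noise, predicted_noise)
    elif loss_type == "huber":
        loss = F.smooth_l1_loss(noise, predicted_noise)
    else:
        raise NotImplementedError(f"unknown loss_type: {loss_type}")

    return loss
```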
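The three `loss_type` options differ only in how each element of the error between the true and predicted noise is penalized. A NumPy sketch that mirrors `F.l1_loss`, `F.mse_loss` and `F.smooth_l1_loss` with PyTorch's defaults (mean reduction, `beta=1.0`), assuming "huber" maps to smooth L1 as in Annotated Diffusion; `noise_loss` is a hypothetical name.

```python
import numpy as np


def noise_loss(noise, predicted_noise, loss_type="l1", beta=1.0):
    """NumPy version of the l1 / l2 / huber options, averaged over all elements."""
    diff = np.abs(noise - predicted_noise)
    if loss_type == "l1":
        per_elem = diff
    elif loss_type == "l2":
        per_elem = diff ** 2
    elif loss_type == "huber":
        # smooth L1: quadratic for |error| < beta, linear beyond it
        per_elem = np.where(diff < beta, 0.5 * diff ** 2 / beta, diff - 0.5 * beta)
    else:
        raise NotImplementedError(loss_type)
    return per_elem.mean()


noise = np.zeros(4)
predicted = np.array([0.5, -2.0, 1.0, 0.0])
for kind in ("l1", "l2", "huber"):
    print(kind, noise_loss(noise, predicted, kind))
# l1 0.875
# l2 1.3125
# huber 0.53125
```

The single large error (-2.0) dominates the L2 value, while the Huber loss grows only linearly in it and still behaves like L2 for small errors, which is one reason it is a common choice here.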

3. Define the reverse process.

1) At time step $T$, start from pure noise sampled from a standard Gaussian:
$$\mathbf{x}_T \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$$

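2) Denoising
From time step $T$ down to 0, the network gradually denoises the sample using the conditional distribution $p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)$ it has learned. The mean of this distribution is reparameterized in terms of our noise predictor $\epsilon_\theta$ (with $\alpha_t = 1-\beta_t$):
$$\mu_\theta(\mathbf{x}_t, t) = \frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta(\mathbf{x}_t, t)\right)$$
Sampling around this mean gives a slightly less noisy image $\mathbf{x}_{t-1}$ at each step, and repeating the process yields a new image that resembles a sample from the real data distribution $q(\mathbf{x}_0)$.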

```python
import torch

# assumes betas, sqrt_one_minus_alphas_cumprod, sqrt_recip_alphas,
# posterior_variance and extract are precomputed elsewhere
@torch.no_grad()
def p_sample(model, x, t, t_index):
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    # 1/sqrt(alpha_t)
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)

    # Equation 11 in the paper:
    # use our model (noise predictor) to predict the mean
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )

    if t_index == 0:
        # last step: return the mean without adding noise
        return model_mean
    else:
        posterior_variance_t = extract(posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        # Algorithm 2 line 4
        return model_mean + torch.sqrt(posterior_variance_t) * noise
```
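Why does Equation 11 build the mean from the predicted noise? If the network predicted the true $\epsilon$ exactly, this mean would coincide with the mean of the forward-process posterior $q(\mathbf{x}_{t-1}\mid\mathbf{x}_t,\mathbf{x}_0)$, namely $\tilde{\mu}_t = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1-\bar{\alpha}_t}\mathbf{x}_0 + \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\mathbf{x}_t$. The NumPy sketch below checks this numerically for every timestep of the linear schedule (assuming $T = 1000$); it is a sanity check on the algebra, not part of the original code, and `eq11_mean` / `posterior_mean` are hypothetical names.

```python
import numpy as np

T = 1000
betas = np.linspace(0.0001, 0.02, T)
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas)
alphas_cumprod_prev = np.concatenate([[1.0], alphas_cumprod[:-1]])


def eq11_mean(x_t, eps, t):
    """Reparameterized mean: 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps)."""
    return (x_t - betas[t] / np.sqrt(1.0 - alphas_cumprod[t]) * eps) / np.sqrt(alphas[t])


def posterior_mean(x_t, x_0, t):
    """Mean of the forward-process posterior q(x_{t-1} | x_t, x_0)."""
    coef_x0 = np.sqrt(alphas_cumprod_prev[t]) * betas[t] / (1.0 - alphas_cumprod[t])
    coef_xt = np.sqrt(alphas[t]) * (1.0 - alphas_cumprod_prev[t]) / (1.0 - alphas_cumprod[t])
    return coef_x0 * x_0 + coef_xt * x_t


rng = np.random.default_rng(0)
x_0 = rng.standard_normal(8)
max_gap = 0.0
for t in range(T):
    eps = rng.standard_normal(8)
    x_t = np.sqrt(alphas_cumprod[t]) * x_0 + np.sqrt(1.0 - alphas_cumprod[t]) * eps
    # with a perfect noise prediction, the two means should agree
    gap = np.max(np.abs(eq11_mean(x_t, eps, t) - posterior_mean(x_t, x_0, t)))
    max_gap = max(max_gap, gap)
print(max_gap)  # only floating-point round-off remains
```

In other words, learning to predict $\epsilon$ is a reparameterization of learning the posterior mean: substituting $\mathbf{x}_0 = (\mathbf{x}_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon)/\sqrt{\bar{\alpha}_t}$ into $\tilde{\mu}_t$ gives Equation 11.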

4. Model Training

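```python
import torch

epochs = 6

for epoch in range(epochs):
    for step, batch in enumerate(dataloader):
        optimizer.zero_grad()

        # assumes each batch is an image tensor of shape (B, C, H, W);
        # if the dataloader yields dicts, pull the image tensor out first
        batch = batch.to(device)
        batch_size = batch.shape[0]

        # Algorithm 1 line 3: sample t uniformly for every example in the batch
        t = torch.randint(0, timesteps, (batch_size,), device=device).long()

        loss = p_losses(denoise_model=model, x_start=batch, t=t, loss_type="huber")

        if step % 100 == 0:
            print("Loss:", loss.item())

        loss.backward()
        optimizer.step()
```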
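To see Algorithm 1 end to end without a U-Net, here is a toy NumPy sketch under strong simplifying assumptions: the data are scalars $x_0 \sim \mathcal{N}(0, 1)$, the schedule has only $T = 10$ steps, and the "model" is a hypothetical per-timestep scalar weight $w_t$ with $\epsilon_\theta(x_t, t) = w_t x_t$, trained with the L2 loss by plain SGD. For this data $x_t \sim \mathcal{N}(0, 1)$ and $\mathrm{Cov}(\epsilon, x_t) = \sqrt{1-\bar{\alpha}_t}$, so the best such predictor is known in closed form, $w_t = \sqrt{1-\bar{\alpha}_t}$, and we can check that training finds it.

```python
import numpy as np

T = 10
betas = np.linspace(0.0001, 0.02, T)
alphas_cumprod = np.cumprod(1.0 - betas)

rng = np.random.default_rng(0)
w = np.zeros(T)  # toy model: eps_theta(x_t, t) = w[t] * x_t
lr, batch_size = 0.01, 512

for step in range(1000):
    x_0 = rng.standard_normal(batch_size)          # x_0 ~ q(x_0)
    t = rng.integers(0, T, size=batch_size)        # t ~ Uniform, one per example
    noise = rng.standard_normal(batch_size)        # eps ~ N(0, 1)
    x_t = np.sqrt(alphas_cumprod[t]) * x_0 + np.sqrt(1.0 - alphas_cumprod[t]) * noise

    predicted = w[t] * x_t
    # gradient of the per-example loss (w_t * x_t - eps)^2 with respect to w_t
    grad_per_example = 2.0 * (predicted - noise) * x_t
    grad_sum = np.zeros(T)
    counts = np.zeros(T)
    np.add.at(grad_sum, t, grad_per_example)
    np.add.at(counts, t, 1.0)
    # average over the examples that drew each timestep, then take an SGD step
    w -= lr * grad_sum / np.maximum(counts, 1.0)

print(np.round(w, 3))                              # learned weights
print(np.round(np.sqrt(1.0 - alphas_cumprod), 3))  # closed-form optimum
```

The structure is the same as the PyTorch loop: draw data, draw an independent timestep and noise for every example, noise the data in closed form, and regress the noise. Only the model is replaced by something small enough to check by hand.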

5. Sampling (Inference)

```python
import torch
from tqdm.auto import tqdm

# Algorithm 2 (including returning all images)
# assumes `timesteps` and p_sample are defined as above
@torch.no_grad()
def p_sample_loop(model, shape):
    device = next(model.parameters()).device

    b = shape[0]
    # start from pure noise (for each example in the batch)
    img = torch.randn(shape, device=device)
    imgs = []

    for i in tqdm(reversed(range(0, timesteps)), desc="sampling loop time step", total=timesteps):
        img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i)
        imgs.append(img.cpu().numpy())
    return imgs


@torch.no_grad()
def sample(model, image_size, batch_size=16, channels=3):
    return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))
```
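A toy NumPy run of the same sampling loop, under assumptions chosen so the answer is known: the "data distribution" is a single point $x_0 = C$ (here $C = 1$, scalar pixels), and the noise predictor is a hypothetical oracle that returns $\epsilon = (x_t - \sqrt{\bar{\alpha}_t}\,C)/\sqrt{1-\bar{\alpha}_t}$ exactly. In that case each reverse step draws from the true $q(x_{t-1}\mid x_t)$, so intermediate samples should have mean $\sqrt{\bar{\alpha}_t}\,C$ and standard deviation $\sqrt{1-\bar{\alpha}_t}$, and the last step (which returns the mean without noise) should land on $C$. `oracle_eps`, `p_sample` and `p_sample_loop` here are NumPy analogues, not the PyTorch functions above.

```python
import numpy as np

T = 1000
betas = np.linspace(0.0001, 0.02, T)
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas)
alphas_cumprod_prev = np.concatenate([[1.0], alphas_cumprod[:-1]])
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

C = 1.0  # the single "data point" (point-mass data distribution)


def oracle_eps(x_t, t):
    """Hypothetical perfect noise predictor for data that is always x_0 = C."""
    return (x_t - np.sqrt(alphas_cumprod[t]) * C) / np.sqrt(1.0 - alphas_cumprod[t])


def p_sample(x, t, rng):
    # Equation 11 mean built from the predicted noise
    mean = (x - betas[t] / np.sqrt(1.0 - alphas_cumprod[t]) * oracle_eps(x, t)) / np.sqrt(alphas[t])
    if t == 0:
        return mean  # last step: no noise added
    return mean + np.sqrt(posterior_variance[t]) * rng.standard_normal(x.shape)


def p_sample_loop(n, rng, record_at=(500,)):
    x = rng.standard_normal(n)  # start from pure noise
    snapshots = {}
    for t in reversed(range(T)):
        if t in record_at:
            snapshots[t] = x.copy()  # x should now look like a draw from q(x_t | x_0 = C)
        x = p_sample(x, t, rng)
    return x, snapshots


rng = np.random.default_rng(0)
x_final, snaps = p_sample_loop(20_000, rng)
print(snaps[500].mean(), np.sqrt(alphas_cumprod[500]) * C)   # empirical vs. expected mean
print(snaps[500].std(), np.sqrt(1.0 - alphas_cumprod[500]))  # empirical vs. expected std
print(np.abs(x_final - C).max())                             # ~0: the last step returns C
```

With a real U-Net the predicted noise is only approximate, so samples are new images rather than exact copies; the oracle just makes the bookkeeping of the loop easy to verify.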

Visualizing the results

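```python
import matplotlib.pyplot as plt
samples = sample(model, image_size=image_size, channels=channels)
# show a random one; samples[-1] is the final (t = 0) batch with shape (B, C, H, W)
random_index = 5
img = samples[-1][random_index].transpose(1, 2, 0)  # (C, H, W) -> (H, W, C) for imshow
plt.imshow(img.squeeze(), cmap="gray")
```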
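Converting a sampled `(C, H, W)` array to the `(H, W, C)` layout that `plt.imshow` expects needs an axis transpose; a plain `reshape` only gives the right picture when there is a single channel. A small NumPy check with made-up values:

```python
import numpy as np

C, H, W = 3, 2, 2
img_chw = np.arange(C * H * W).reshape(C, H, W)  # channel-first, like the sampled images

hwc_transpose = img_chw.transpose(1, 2, 0)  # moves the channel axis last
hwc_reshape = img_chw.reshape(H, W, C)      # re-reads the same memory in a new shape

print(hwc_transpose[0, 0])  # [0 4 8]: the three channel values of pixel (0, 0)
print(hwc_reshape[0, 0])    # [0 1 2]: three pixels of channel 0, not one RGB pixel

gray = np.arange(H * W).reshape(1, H, W)
print(np.array_equal(gray.reshape(H, W, 1), gray.transpose(1, 2, 0)))  # True
```

For single-channel data such as grayscale images both forms agree, so the difference only shows up once `channels` is larger than 1.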