이전 논문 리뷰 후 구현을 해보았다.
전체 코드는 깃허브 링크를 보면 된다.
우선 전체 학습 구현은 위의 알고리즘이 전부이다.
model은 생략하겠다. model에 대한 논문이 아니기 때문
model은 U-Net에 time embedding을 활용한 것을 가져와서 base로 사용하였다.
우선 학습을 위해서 필요한 beta, alpha를 미리 정의해둔다.
T = 1000
beta_start = 1e-4
beta_end = 0.02
betas = torch.linspace(beta_start, beta_end, T).to(device)
alphas = 1 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0).to(device)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).to(device)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - alphas_cumprod).to(device)
위 의사코드를 보고 이해하면 매우 쉽다.
for epoch in tqdm(range(epochs)):
total_loss = 0
for i, (x, _) in enumerate(train_loader):
x = x.to(device)
batch_size = x.size(0)
t = torch.randint(0, T, (batch_size,)).to(device)
noise = torch.randn_like(x).to(device)
# 이미지에 노이즈 추가
sqrt_alphas_cumprod_t = sqrt_alphas_cumprod[t].reshape(x.shape[0], 1, 1, 1)
sqrt_one_minus_alphas_cumprod_t = sqrt_one_minus_alphas_cumprod[t].reshape(
x.shape[0], 1, 1, 1)
# reparametrization trick
noise_x = sqrt_alphas_cumprod_t * x + sqrt_one_minus_alphas_cumprod_t * noise
noise_pred = model(noise_x, t)
noise_loss = criterion(noise_pred, noise)
optimizer.zero_grad()
noise_loss.backward()
optimizer.step()
total_loss += noise_loss.item()
print(f"Epoch [{epoch}], Loss: {total_loss/len(train_loader)}")
설명하자면 임의의 time t를 뽑고 로 빠르게 만든다.
이를 model에 넣어서 noise를 예측하고 실제 noise와 mseloss를 이용해서 학습한다.
사실 이게 끝이다. 매우 간단하다.
역시 위 의사코드가 전부이다.
@torch.no_grad()
def sample(model, img_size, alphas, alphas_cumprod, betas, T, device, batch_size=64):
model.eval()
x = torch.randn(batch_size, 1, img_size, img_size).to(device) # 순수 노이즈로부터 시작
noise_to_x = [x]
for t in reversed(range(T)):
# 현재 타임스텝 t
t_tensor = torch.full((batch_size,), t, device=device, dtype=torch.long)
# 노이즈 예측합니다.
noise_pred = model(x, t_tensor)
beta_t = betas[t].to(device)
# 이전 x를 계산합니다.
if t > 0:
noise = torch.randn_like(x)
else:
noise = torch.zeros_like(x)
x = (1 / torch.sqrt(alphas[t])) * (x - ((1 - alphas[t]) / torch.sqrt(1 - alphas_cumprod[t])) * noise_pred) + torch.sqrt(beta_t) * noise
return x
위처럼 구현이 되는데 처음 random noise
x = torch.randn(batch_size, 1, img_size, img_size).to(device)
에서 시작해서 T를 reverse로 복구한다.
model에 현재의 noise된 image를 넣어서 noise를 예측하고
기존에 유도된 수식에 따라서
x = (1 / torch.sqrt(alphas[t])) * (x - ((1 - alphas[t]) / torch.sqrt(1 - alphas_cumprod[t])) * noise_pred) + torch.sqrt(beta_t) * noise
다음 x를 이렇게 만들어낸다.
사실 sampling도 이게 전부이다.
의외로 매우 간단했다.