- This post covers the basic idea of diffusion and walks through the diffusion process by generating MNIST digits.
- It is based on the DDPM model.
DDPM (Denoising Diffusion Probabilistic Models): learns to reverse a gradual Gaussian noising process, denoising step by step directly in pixel space (Ho et al., 2020).
LDM (Latent Diffusion Models): runs the same diffusion process in a compressed latent space produced by an autoencoder, which is much cheaper; Stable Diffusion is the best-known example (Rombach et al., 2022).
Denoising training:
The model learns to restore data to which random noise has been added back to the original data.
Data generation:
Once training is complete, the model can generate new data by starting from random noise and progressively removing it.
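For reference, in DDPM the noise is Gaussian and is added according to a fixed variance schedule $\beta_t$; the forward (noising) process and its closed form are (Ho et al., 2020):

$$
q(x_t \mid x_{t-1}) = \mathcal{N}\!\left(x_t;\ \sqrt{1-\beta_t}\,x_{t-1},\ \beta_t I\right),
\qquad
x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon
$$

where $\alpha_t = 1-\beta_t$, $\bar\alpha_t = \prod_{s=1}^{t}\alpha_s$, and $\epsilon \sim \mathcal{N}(0, I)$. The toy example below uses a simpler corruption (uniform noise, linear mixing), but the idea is the same.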

The denoising model is a minimal UNet: three down blocks and three up blocks joined by skip connections.

```python
import torch
from torch import nn

class BasicUNet(nn.Module):
    """A minimal UNet: three down and three up conv blocks with skip connections."""
    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        self.down_layers = torch.nn.ModuleList([
            nn.Conv2d(in_channels, 32, kernel_size=5, padding=2),
            nn.Conv2d(32, 64, kernel_size=5, padding=2),
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
        ])
        self.up_layers = torch.nn.ModuleList([
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
            nn.Conv2d(64, 32, kernel_size=5, padding=2),
            nn.Conv2d(32, out_channels, kernel_size=5, padding=2),
        ])
        self.act = nn.SiLU()                        # activation function
        self.downscale = nn.MaxPool2d(2)            # halve the spatial resolution
        self.upscale = nn.Upsample(scale_factor=2)  # double the spatial resolution

    def forward(self, x):
        h = []
        for i, l in enumerate(self.down_layers):
            x = self.act(l(x))        # conv + activation
            if i < 2:                 # for all but the final down layer
                h.append(x)           # store activations for the skip connection
                x = self.downscale(x)
        for i, l in enumerate(self.up_layers):
            if i > 0:                 # for all but the first up layer
                x = self.upscale(x)
                x += h.pop()          # add back the stored activations (skip connection)
            x = self.act(l(x))        # conv + activation
        return x
```
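A quick sanity check (not in the original) confirms that the output shape matches the input and shows the parameter count:

```python
net = BasicUNet()
x = torch.rand(8, 1, 28, 28)   # a dummy MNIST-sized batch
print(net(x).shape)            # torch.Size([8, 1, 28, 28]) -- same shape in, same shape out
print(sum(p.numel() for p in net.parameters()))  # number of trainable parameters
```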
Corruption is a simple linear interpolation between the clean image and uniform noise, `noisy_x = (1-amount)*x + amount*noise`: with `amount=0` the image is untouched, and with `amount=1` it is replaced by pure noise.

```python
def corrupt(x, amount):
    """Corrupt the input x by mixing it with noise according to amount."""
    noise = torch.rand_like(x)
    amount = amount.view(-1, 1, 1, 1)  # reshape so broadcasting works per-sample
    return x*(1-amount) + noise*amount
```
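The snippets below assume a batch of MNIST images `x`; a minimal setup sketch (the dataset root `mnist/` and the batch size are arbitrary choices):

```python
import torch
import torchvision
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

device = 'cuda' if torch.cuda.is_available() else 'cpu'
dataset = torchvision.datasets.MNIST(root='mnist/', train=True, download=True,
                                     transform=torchvision.transforms.ToTensor())
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
x, y = next(iter(train_dataloader))  # x: (8, 1, 28, 28), values in [0, 1]
```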
```python
# Plotting the input data
fig, axs = plt.subplots(2, 1, figsize=(12, 5))
axs[0].set_title('Input data')
axs[0].imshow(torchvision.utils.make_grid(x)[0], cmap='Greys')

# Adding noise
amount = torch.linspace(0, 1, x.shape[0])  # Left to right -> more corruption
noised_x = corrupt(x, amount)

# Plotting the noised version
axs[1].set_title('Corrupted data (-- amount increases -->)')
axs[1].imshow(torchvision.utils.make_grid(noised_x)[0], cmap='Greys');
```

Training corrupts each batch by a random amount and teaches the network to recover the clean images. The setup lines (dataloader, `net`, `opt`, `loss_fn`, `n_epochs`) were not in the original snippet and are standard choices:

```python
train_dataloader = DataLoader(dataset, batch_size=128, shuffle=True)  # larger batches for training
n_epochs = 3
net = BasicUNet().to(device)
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()
losses = []

for epoch in range(n_epochs):
    for x, y in train_dataloader:
        # Prepare clean x and noisy x
        x = x.to(device)
        noise_amount = torch.rand(x.shape[0]).to(device)  # random corruption amount in [0, 1)
        noisy_x = corrupt(x, noise_amount)
        # noisy x -> denoised x
        pred = net(noisy_x)
        # Calculate the MSE loss
        loss = loss_fn(pred, x)
        # Backprop and update
        opt.zero_grad()
        loss.backward()
        opt.step()
        losses.append(loss.item())
```
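To sanity-check training, the collected losses can be plotted (a small sketch, assuming the `losses` list from the loop above):

```python
plt.plot(losses)
plt.title('Training loss (MSE)')
plt.xlabel('step')
plt.ylabel('loss')
plt.show()
```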

Sampling starts from pure noise and repeatedly blends the current `x` with the model's denoised prediction, trusting the prediction more at each step:

```python
n_steps = 5
x = torch.rand(8, 1, 28, 28).to(device)  # x starts from pure noise
step_history = [x.detach().cpu()]
pred_output_history = []

for step in range(n_steps):
    with torch.no_grad():
        pred = net(x)  # denoise the current x
    mix_factor = 1/(n_steps - step)  # ratio of pred to the current x
    x = x*(1-mix_factor) + pred*mix_factor  # the larger the step, the smaller the share of the old x
    pred_output_history.append(pred.detach().cpu())
    step_history.append(x.detach().cpu())
```
```python
fig, axs = plt.subplots(n_steps, 2, figsize=(9, 4), sharex=True)
axs[0,0].set_title('x (model input)')
axs[0,1].set_title('model prediction')
for i in range(n_steps):
    axs[i, 0].imshow(torchvision.utils.make_grid(step_history[i])[0].clip(0, 1), cmap='Greys')
    axs[i, 1].imshow(torchvision.utils.make_grid(pred_output_history[i])[0].clip(0, 1), cmap='Greys')
```
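Note that the `mix_factor` schedule above is a hand-rolled simplification. In DDPM proper, the network predicts the added noise $\epsilon$ rather than the clean image, and each sampling step draws from the learned reverse process:

$$
p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\!\left(x_{t-1};\ \mu_\theta(x_t, t),\ \sigma_t^2 I\right),
\qquad
\mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta(x_t, t)\right)
$$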
