We'll go over the original DDPM paper (Ho et al., 2020) and implement it step-by-step in PyTorch, based on Phil Wang's implementation, which itself is based on the original TensorFlow implementation.
Alright, let's dive in!
Reference:
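In a bit more detail for images, the set-up consists of 2 processes:
- a fixed (predefined) forward diffusion process \(q\) that gradually adds Gaussian noise to an image, until you end up with pure noise
- a learned reverse denoising diffusion process \(p_\theta\), where a neural network is trained to gradually denoise an image starting from pure noise, until you end up with an actual image

Both the forward and the reverse process, indexed by \(t\), run for a finite number of time steps \(T\) (the DDPM authors use \(T = 1000\)).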
Let's write this down more formally, as ultimately we need a tractable loss function which our neural network needs to optimize.
Forward Diffusion Process
Let \(q(\mathbf{x}_0)\) be the real data distribution, say of "real images".
\(\mathbf{x}_0 \sim q(\mathbf{x}_0)\): We can sample from this distribution to get an image.
\(\beta_t\) defines a so-called "variance schedule": linear, quadratic, cosine, etc.
We define the forward diffusion process \(q(\mathbf{x}_t | \mathbf{x}_{t-1})\), which adds Gaussian noise at each time step \(t\) according to a known variance schedule \(0 < \beta_1 < \beta_2 < \dots < \beta_T < 1\), as
$$ q(\mathbf{x}_t | \mathbf{x}_{t-1}) = \mathcal{N}(\mathbf{x}_t; \sqrt{1 - \beta_t} \mathbf{x}_{t-1}, \beta_t \mathbf{I}). $$
We can parametrize this as \(\mathbf{x}_t = \sqrt{1 - \beta_t} \mathbf{x}_{t-1} + \sqrt{\beta_t} \mathbf{\epsilon}\) by sampling \(\mathbf{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})\).
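A minimal sketch of a single forward step, assuming PyTorch and a toy batch of random "images" in \([-1, 1]\); `forward_step` is a hypothetical helper name, not part of the original code. With zero noise the step only shrinks the input by \(\sqrt{1 - \beta_t}\), which makes the two terms of the reparametrization easy to check separately.

```python
import torch

def forward_step(x_prev: torch.Tensor, beta_t: float, noise: torch.Tensor) -> torch.Tensor:
    """One step of q(x_t | x_{t-1}): x_t = sqrt(1 - beta_t) * x_{t-1} + sqrt(beta_t) * eps."""
    return (1.0 - beta_t) ** 0.5 * x_prev + beta_t ** 0.5 * noise

torch.manual_seed(0)
x0 = torch.rand(4, 1, 8, 8) * 2 - 1   # toy "images" with values in [-1, 1]
eps = torch.randn_like(x0)            # eps ~ N(0, I)
x1 = forward_step(x0, beta_t=0.02, noise=eps)

# With zero noise, the step only rescales the input by sqrt(1 - beta_t)
x1_no_noise = forward_step(x0, beta_t=0.02, noise=torch.zeros_like(x0))
```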
Reverse Diffusion Process
Now, if we knew the conditional distribution \(p(\mathbf{x}_{t-1} | \mathbf{x}_t)\), then we could run the process in reverse: by sampling some random Gaussian noise \(\mathbf{x}_T\), and then gradually "denoise" it so that we end up with a sample from the real distribution \(\mathbf{x}_0\).
However, we don't know \(p(\mathbf{x}_{t-1} | \mathbf{x}_t)\). It's intractable since it requires knowing the distribution of all possible images in order to calculate this conditional probability.
Hence, we're going to leverage a neural network to approximate (learn) this conditional probability distribution, let's call it \(p_\theta (\mathbf{x}_{t-1} | \mathbf{x}_t)\), with \(\theta\) being the parameters of the neural network, updated by gradient descent.
Neural Network \(\mathbf{\epsilon}_\theta(\mathbf{x}_t, t)\)
Ok, so we need a neural network to represent a (conditional) probability distribution of the reverse process.
If we assume this reverse process is Gaussian as well, then recall that any Gaussian distribution is defined by 2 parameters: a mean, parametrized by \(\mu_\theta\), and a variance, parametrized by \(\Sigma_\theta\). So we can parametrize the process as
$$ p_\theta (\mathbf{x}_{t-1} | \mathbf{x}_t) = \mathcal{N}(\mathbf{x}_{t-1}; \mu_\theta(\mathbf{x}_{t},t), \Sigma_\theta (\mathbf{x}_{t},t))$$
where the mean and variance are also conditioned on the noise level \(t\).
In DDPM, the neural network only learns (represents) the mean \(\mu_\theta\) of this conditional probability distribution; the variance is kept fixed to a time-dependent constant determined by the variance schedule.
The variational lower bound (also called the ELBO) can be used to minimize the negative log-likelihood of a ground-truth data sample \(\mathbf{x}_0\): the negative ELBO is an upper bound on \(-\log p_\theta(\mathbf{x}_0)\), so minimizing it pushes the negative log-likelihood down.
The ELBO for this process is a sum of losses at each time step \(t\), \(L = L_0 + L_1 + ... + L_T\) (please refer to variational autoencoders (VAEs) for further details regarding the ELBO).
Each term (except for \(L_0\)) of the loss is actually the KL divergence between 2 Gaussian distributions which can be written explicitly as an L2-loss with respect to the means!
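To see why a KL divergence between two Gaussians with the same fixed variance reduces to an L2-loss on the means, the sketch below (assuming PyTorch's `torch.distributions`; `kl_same_variance` is a hypothetical helper) compares the closed form \(\|\mu_q - \mu_p\|^2 / (2\sigma^2)\) against `kl_divergence`. When the two variances differ but are both fixed, as in DDPM, \(\mathrm{KL}(q \| p)\) still equals \(\|\mu_q - \mu_p\|^2 / (2\sigma_p^2)\) plus a term that does not depend on the means.

```python
import torch
from torch.distributions import Normal, kl_divergence

def kl_same_variance(mu_q, mu_p, sigma):
    """Closed form of KL(N(mu_q, sigma^2 I) || N(mu_p, sigma^2 I))."""
    return ((mu_q - mu_p) ** 2).sum() / (2 * sigma ** 2)

torch.manual_seed(0)
mu_q = torch.randn(16)
mu_p = torch.randn(16)
sigma = 0.3

# Reference: per-dimension KL from torch.distributions, summed over dimensions
kl_ref = kl_divergence(Normal(mu_q, sigma), Normal(mu_p, sigma)).sum()
kl_closed = kl_same_variance(mu_q, mu_p, sigma)
```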
During training, we can optimize random terms of the loss function \(L\) thanks to a "nice property" of the forward process: we randomly sample \(t\) during training and optimize \(L_t\).
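We can sample \(\mathbf{x}_t\) at any arbitrary noise level conditioned on \(\mathbf{x}_0\) directly (since sums of independent Gaussians are also Gaussian):
$$ q(\mathbf{x}_t | \mathbf{x}_0) = \mathcal{N}(\mathbf{x}_t; \sqrt{\bar{\alpha}_t} \mathbf{x}_0, (1 - \bar{\alpha}_t) \mathbf{I}), \quad \text{i.e.} \quad \mathbf{x}_t = \sqrt{\bar{\alpha}_t} \mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t} \mathbf{\epsilon}, $$
with \(\alpha_t := 1 - \beta_t\) and \(\bar{\alpha}_t := \prod_{s=1}^{t} \alpha_s\).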
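The nice property states that \(\mathbf{x}_t = \sqrt{\bar{\alpha}_t} \mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t} \mathbf{\epsilon}\). A quick way to check this without sampling is to track, through \(T\) single forward steps, the coefficient in front of \(\mathbf{x}_0\) and the variance of the accumulated noise (variances of independent Gaussian noises add). This sketch assumes the linear schedule from \(10^{-4}\) to \(0.02\) with \(T = 1000\), chosen only for illustration.

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T, dtype=torch.float64)  # linear schedule (illustrative)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)                # alpha_bar_t

# Write x_t = m_t * x_0 + (Gaussian noise with variance v_t) and apply the one-step
# forward process x_t = sqrt(alpha_t) * x_{t-1} + sqrt(beta_t) * eps_t repeatedly.
m, v = 1.0, 0.0
ms, vs = [], []
for t in range(T):
    m = alphas[t].sqrt().item() * m
    v = alphas[t].item() * v + betas[t].item()  # old noise is rescaled, new noise adds beta_t
    ms.append(m)
    vs.append(v)

ms = torch.tensor(ms, dtype=torch.float64)
vs = torch.tensor(vs, dtype=torch.float64)
```

The recursion \(v_t = \alpha_t v_{t-1} + \beta_t\) with \(v_0 = 0\) gives exactly \(1 - \bar{\alpha}_t\), which is why a single Gaussian draw can replace \(t\) sequential steps.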
Moreover, we can instead reparametrize the mean to make the neural network learn (predict) the added noise.
$$ \mathbf{\mu}_\theta(\mathbf{x}_t, t) = \frac{1}{\sqrt{\alpha_t}} \left( \mathbf{x}_t - \frac{\beta_t}{\sqrt{1- \bar{\alpha}_t}} \mathbf{\epsilon}_\theta(\mathbf{x}_t, t) \right)$$
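As a sanity check of this reparametrization, the sketch below assumes a perfect noise predictor (one that returns the true \(\mathbf{\epsilon}\) used to build \(\mathbf{x}_t\)) and compares \(\mathbf{\mu}_\theta\) with the mean of the true posterior \(q(\mathbf{x}_{t-1} | \mathbf{x}_t, \mathbf{x}_0)\) given in the DDPM paper, \(\tilde{\mathbf{\mu}}_t = \frac{\sqrt{\bar{\alpha}_{t-1}} \beta_t}{1 - \bar{\alpha}_t} \mathbf{x}_0 + \frac{\sqrt{\alpha_t} (1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t} \mathbf{x}_t\). The two coincide, which is why predicting the noise is enough to recover the mean. The schedule and the time step are arbitrary choices for illustration, and float64 keeps the comparison tight.

```python
import torch

torch.manual_seed(0)
T = 200
betas = torch.linspace(1e-4, 0.02, T, dtype=torch.float64)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)

t = 50  # arbitrary 0-indexed time step (needs t >= 1 so that alpha_bar_{t-1} exists)
x0 = torch.randn(3, 8, dtype=torch.float64)
eps = torch.randn_like(x0)
x_t = alphas_cumprod[t].sqrt() * x0 + (1 - alphas_cumprod[t]).sqrt() * eps  # nice property

# mu_theta from the noise-prediction parametrization, with a perfect noise predictor
mu_eps = (x_t - betas[t] / (1 - alphas_cumprod[t]).sqrt() * eps) / alphas[t].sqrt()

# Mean of the true posterior q(x_{t-1} | x_t, x_0)
coef_x0 = alphas_cumprod[t - 1].sqrt() * betas[t] / (1 - alphas_cumprod[t])
coef_xt = alphas[t].sqrt() * (1 - alphas_cumprod[t - 1]) / (1 - alphas_cumprod[t])
mu_posterior = coef_x0 * x0 + coef_xt * x_t
```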
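The training algorithm (Algorithm 1 in the paper) now looks as follows:
1. We take a random sample \(\mathbf{x}_0\) from the real data distribution \(q(\mathbf{x}_0)\).
1. We sample a noise level \(t\) uniformly between \(1\) and \(T\) (i.e., a random time step).
1. We sample some noise \(\mathbf{\epsilon}\) from a Gaussian distribution and corrupt the input by this noise at level \(t\), using the nice property defined above.
1. The neural network \(\mathbf{\epsilon}_\theta(\mathbf{x}_t, t)\) is trained to predict this noise from the corrupted image \(\mathbf{x}_t\).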
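The neural network \(\mathbf{\epsilon}_\theta(\mathbf{x}_t, t)\) takes a noisy image and its noise level and returns the predicted noise, so its output has the same shape as its input. The DDPM authors chose a U-Net for this: it first downsamples the input (i.e. makes the input smaller in terms of spatial resolution), after which upsampling is performed, with skip connections between the downsampling and upsampling paths at matching resolutions.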
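A minimal sketch of the resolution bookkeeping, assuming the same convolution settings (kernel size 4, stride 2, padding 1) as the `Downsample` and `Upsample` helpers defined below: each strided convolution halves the height and width, and each transposed convolution doubles them again, so a 28x28 input goes to 14x14 and 7x7 and back.

```python
import torch
from torch import nn

dim = 8
down = nn.Conv2d(dim, dim, kernel_size=4, stride=2, padding=1)         # halves H and W
up = nn.ConvTranspose2d(dim, dim, kernel_size=4, stride=2, padding=1)  # doubles H and W

x = torch.randn(2, dim, 28, 28)
h1 = down(x)     # (2, 8, 14, 14)
h2 = down(h1)    # (2, 8, 7, 7)
y = up(up(h2))   # back to (2, 8, 28, 28)
```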
Below, we implement this network, step-by-step.
```python
!pip install -q -U einops datasets matplotlib tqdm

import math
from inspect import isfunction
from functools import partial

%matplotlib inline
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from einops import rearrange

import torch
from torch import nn, einsum
import torch.nn.functional as F
```
The `Residual` module simply adds the input to the output of a particular function (in other words, it adds a residual connection to that function). We also define a few small helpers (`exists`, `default`) and aliases for the upsampling and downsampling operations.

```python
def exists(x):
    return x is not None

def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x

def Upsample(dim):
    # transposed convolution: doubles height and width
    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)

def Downsample(dim):
    # strided convolution: halves height and width (rounding down)
    return nn.Conv2d(dim, dim, 4, 2, 1)
```
The `SinusoidalPositionEmbeddings` module takes a tensor of shape `(batch_size,)` as input (i.e. the noise levels of several noisy images in a batch) and turns it into a tensor of shape `(batch_size, dim)`, with `dim` being the dimensionality of the position embeddings.

```python
class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings
```
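A small usage sketch of the same computation, rewritten as a standalone function (`sinusoidal_embedding` is a hypothetical name) so it runs on its own. The frequencies decrease geometrically from 1 to \(1/10000\); at noise level 0 the sine half of the embedding is all zeros and the cosine half all ones.

```python
import math
import torch

def sinusoidal_embedding(time: torch.Tensor, dim: int) -> torch.Tensor:
    """Same computation as SinusoidalPositionEmbeddings.forward, as a plain function."""
    half_dim = dim // 2
    freqs = torch.exp(torch.arange(half_dim, device=time.device) * -(math.log(10000) / (half_dim - 1)))
    args = time[:, None] * freqs[None, :]               # (batch_size, half_dim)
    return torch.cat((args.sin(), args.cos()), dim=-1)  # (batch_size, dim)

t = torch.tensor([0, 1, 50, 199])  # one noise level per image in the batch
emb = sinusoidal_embedding(t, dim=28)
```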
Next, we define the core building blocks of the U-Net model: a basic convolutional block, a ResNet block and a ConvNeXT block.

```python
class Block(nn.Module):
    def __init__(self, dim, dim_out, groups = 8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift = None):
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x

class ResnetBlock(nn.Module):
    """https://arxiv.org/abs/1512.03385"""

    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out))
            if exists(time_emb_dim)
            else None
        )

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        h = self.block1(x)

        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb)
            # add the projected time embedding to every spatial position
            h = rearrange(time_emb, "b c -> b c 1 1") + h

        h = self.block2(h)
        return h + self.res_conv(x)

class ConvNextBlock(nn.Module):
    """https://arxiv.org/abs/2201.03545"""

    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim))
            if exists(time_emb_dim)
            else None
        )

        # depthwise 7x7 convolution
        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)

        self.net = nn.Sequential(
            nn.GroupNorm(1, dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, padding=1),
            nn.GELU(),
            nn.GroupNorm(1, dim_out * mult),
            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1),
        )

        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        h = self.ds_conv(x)

        if exists(self.mlp) and exists(time_emb):
            condition = self.mlp(time_emb)
            h = h + rearrange(condition, "b c -> b c 1 1")

        h = self.net(h)
        return h + self.res_conv(x)
```
The attention module, which the DDPM authors added in between the convolutional blocks.

```python
class Attention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )
        q = q * self.scale

        sim = einsum("b h d i, b h d j -> b h i j", q, k)
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        out = einsum("b h i j, b h d j -> b h i d", attn, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
        return self.to_out(out)
```
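To check what the two `einsum` calls in `Attention.forward` compute, the sketch below (random tensors, shapes chosen for illustration) compares them with the textbook form \(\mathrm{softmax}(QK^\top/\sqrt{d})\,V\). Subtracting the row-wise maximum before the softmax does not change the result, since softmax is invariant to adding a constant to all of its inputs; it only guards against overflow in `exp`.

```python
import torch
from torch import einsum

torch.manual_seed(0)
b, heads, d, n = 2, 4, 32, 49                 # n = x * y flattened pixels (e.g. a 7x7 feature map)
q = torch.randn(b, heads, d, n) * d ** -0.5   # already scaled, as in Attention.forward
k = torch.randn(b, heads, d, n)
v = torch.randn(b, heads, d, n)

# As in Attention.forward: similarities, max subtraction, softmax over keys, weighted sum of values
sim = einsum("b h d i, b h d j -> b h i j", q, k)
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
out = einsum("b h i j, b h d j -> b h i d", sim.softmax(dim=-1), v)

# Textbook form with Q, K, V laid out as (b, h, n, d); the scaling is already in q
Q, K, V = (t.transpose(-1, -2) for t in (q, k, v))
out_ref = torch.softmax(Q @ K.transpose(-1, -2), dim=-1) @ V
```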
We also define a linear attention variant, which the U-Net below uses in its downsampling and upsampling stages (the full `Attention` is only used in the middle of the network). Its cost grows linearly rather than quadratically with the number of pixels.

```python
class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)

        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1),
                                    nn.GroupNorm(1, dim))

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )

        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)

        q = q * self.scale
        # (d x e) context per head, independent of the number of pixels n
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)

        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
        return self.to_out(out)
```
The `PreNorm` class will be used to apply groupnorm before the attention layer, as we'll see further.

```python
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)
```
Now, we've defined all building blocks (position embeddings, ResNet/ConvNeXT blocks, attention and group normalization).
It's time to define the entire neural network.
The network takes a batch of noisy images of shape `(batch_size, num_channels, height, width)` and a batch of noise levels of shape `(batch_size,)` as input, and returns a tensor of shape `(batch_size, num_channels, height, width)`, i.e. the predicted noise.
The network is built up as follows:
1. A convolutional layer is applied on the batch of noisy images, and position embeddings are computed for the noise levels.
1. A sequence of downsampling stages is applied.
Each downsampling stage consists of 2 ResNet/ConvNeXT blocks + groupnorm + attention + residual connection + a downsample operation.
1. At the middle of the network, ResNet or ConvNeXT blocks are applied again, interleaved with attention.
1. A sequence of upsampling stages is applied.
Each upsampling stage first concatenates the feature map stored at the matching resolution of the downsampling path, and then applies 2 ResNet/ConvNeXT blocks + groupnorm + attention + residual connection + an upsample operation.
1. Finally, a ResNet/ConvNeXT block followed by a convolutional layer is applied.
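The channel bookkeeping in `Unet.__init__` can be traced by hand. This sketch reproduces it in plain Python for the configuration trained later in this notebook (`dim=28`, `channels=1`, `dim_mults=(1, 2, 4)`). The first upsampling block receives twice the channels of its stage because the skip connection from the downsampling path is concatenated along the channel dimension.

```python
# Channel bookkeeping of Unet.__init__ for dim=28, dim_mults=(1, 2, 4)
dim, dim_mults = 28, (1, 2, 4)

init_dim = dim // 3 * 2                               # default(init_dim, dim // 3 * 2)
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))               # (in, out) channels of each downsampling stage
ups_in_out = list(reversed(in_out[1:]))               # stages traversed on the way up

# input channels of the first upsampling block: current features + concatenated skip connection
first_up_in = ups_in_out[0][1] * 2
```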
```python
class Unet(nn.Module):
    def __init__(
        self,
        dim,
        init_dim=None,
        out_dim=None,
        dim_mults=(1, 2, 4, 8),
        channels=3,
        with_time_emb=True,
        resnet_block_groups=8,
        use_convnext=True,
        convnext_mult=2,
    ):
        super().__init__()

        # determine dimensions
        self.channels = channels

        init_dim = default(init_dim, dim // 3 * 2)
        self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3)

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        if use_convnext:
            block_klass = partial(ConvNextBlock, mult=convnext_mult)
        else:
            block_klass = partial(ResnetBlock, groups=resnet_block_groups)

        # time embeddings
        if with_time_emb:
            time_dim = dim * 4
            self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(dim),
                nn.Linear(dim, time_dim),
                nn.GELU(),
                nn.Linear(time_dim, time_dim),
            )
        else:
            time_dim = None
            self.time_mlp = None

        # layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.ModuleList(
                    [
                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Downsample(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(
                nn.ModuleList(
                    [
                        block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        Upsample(dim_in) if not is_last else nn.Identity(),
                    ]
                )
            )

        out_dim = default(out_dim, channels)
        self.final_conv = nn.Sequential(
            block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
        )

    def forward(self, x, time):
        x = self.init_conv(x)

        t = self.time_mlp(time) if exists(self.time_mlp) else None

        h = []

        # downsample
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            h.append(x)
            x = downsample(x)

        # bottleneck
        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        # upsample
        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            x = upsample(x)

        return self.final_conv(x)
```
Next, we define several variance schedules \(\beta_t\) for the forward process: cosine, linear (the schedule used in the DDPM paper), quadratic and sigmoid.

```python
def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule as proposed in https://arxiv.org/abs/2102.09672
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0.0001, 0.9999)

def linear_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start, beta_end, timesteps)

def quadratic_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2

def sigmoid_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    betas = torch.linspace(-6, 6, timesteps)
    return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
```
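One way to compare schedules is to look at how much of \(\mathbf{x}_0\) survives at the last step, \(\sqrt{\bar{\alpha}_T}\). The sketch below (float64, helper names hypothetical) does this for the linear schedule: with \(T = 1000\), the value used in the DDPM paper, almost nothing of \(\mathbf{x}_0\) is left, whereas with the same endpoints and \(T = 200\), as used below, a noticeable fraction of the signal remains in \(\mathbf{x}_T\).

```python
import torch

def linear_betas(timesteps):
    # same endpoints as linear_beta_schedule above
    return torch.linspace(0.0001, 0.02, timesteps, dtype=torch.float64)

def signal_left_at_T(betas):
    """sqrt(alpha_bar_T): weight of x_0 in x_T = sqrt(alpha_bar_T) x_0 + sqrt(1 - alpha_bar_T) eps."""
    return torch.cumprod(1.0 - betas, dim=0)[-1].sqrt().item()

signal_200 = signal_left_at_T(linear_betas(200))
signal_1000 = signal_left_at_T(linear_betas(1000))
```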
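To start with, we use the linear schedule with \(T = 200\) time steps and precompute the quantities derived from \(\beta_t\) that we'll need later, such as \(\alpha_t\), the cumulative products \(\bar{\alpha}_t\) and the variance of the posterior \(q(\mathbf{x}_{t-1} | \mathbf{x}_t, \mathbf{x}_0)\). Each of them is a 1-dimensional tensor with one entry per time step. We also define an `extract` function, which will allow us to extract the appropriate \(t\) index for a batch of indices.

```python
timesteps = 200

# define beta schedule
betas = linear_beta_schedule(timesteps=timesteps)

# define alphas
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)

# calculations for diffusion q(x_t | x_{t-1}) and others
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)

# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

def extract(a, t, x_shape):
    # gather a[t] for every example in the batch and reshape to (batch_size, 1, 1, ...)
    # so that the result broadcasts against tensors of shape x_shape
    batch_size = t.shape[0]
    out = a.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
```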
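A small usage sketch of `extract` (copied from above so the snippet runs on its own), with a made-up per-time-step tensor: each example in the batch gets its own entry, reshaped to `(batch_size, 1, 1, 1)` so it broadcasts over channels, height and width.

```python
import torch

def extract(a, t, x_shape):
    # same helper as above: pick a[t] for every example and reshape for broadcasting
    batch_size = t.shape[0]
    out = a.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

a = torch.linspace(0.1, 1.0, 10)   # a per-time-step quantity, e.g. sqrt(alpha_bar_t)
t = torch.tensor([0, 9, 4])        # one time step per example in a batch of 3
x = torch.randn(3, 1, 28, 28)      # batch of images

coef = extract(a, t, x.shape)      # shape (3, 1, 1, 1)
scaled = coef * x                  # broadcasts over channels, height and width
```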
Next, we define the forward diffusion `q_sample`, which uses the nice property to noise \(\mathbf{x}_0\) directly to level \(t\).

```python
from PIL import Image
import requests
from torchvision.transforms import Compose, ToTensor, Lambda, ToPILImage, CenterCrop, Resize
import numpy as np

# forward diffusion (using the nice property)
def q_sample(x_start, t, noise=None):
    if noise is None:
        noise = torch.randn_like(x_start)

    sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x_start.shape
    )

    ##################### Quiz #####################
    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
    ################################################
```
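Because \(\mathbf{x}_t\) is a known linear combination of \(\mathbf{x}_0\) and \(\mathbf{\epsilon}\), knowing the noise exactly means knowing \(\mathbf{x}_0\), which is one way to see why the noise is a sensible prediction target. The sketch below (self-contained, using the linear schedule with \(T = 200\) and random stand-in images) applies the same formula as `q_sample` to a batch and inverts it; the coefficients are indexed directly instead of through `extract`.

```python
import torch

torch.manual_seed(0)
timesteps = 200
betas = torch.linspace(0.0001, 0.02, timesteps)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
sqrt_alphas_cumprod = alphas_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod).sqrt()

x_start = torch.rand(2, 1, 28, 28) * 2 - 1   # stand-in images in [-1, 1]
t = torch.tensor([10, 150])                  # one time step per image
noise = torch.randn_like(x_start)

# Same formula as q_sample, with per-example coefficients reshaped to (b, 1, 1, 1)
a = sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
s = sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)
x_noisy = a * x_start + s * noise

# If the noise were known exactly, x_0 could be recovered from x_t
x0_recovered = (x_noisy - s * noise) / a
```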
The loss `p_losses` noises an image with `q_sample`, lets the model predict the added noise, and compares the prediction with the true noise.

```python
def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):
    if noise is None:
        noise = torch.randn_like(x_start)

    ##################### Quiz #####################
    # Hint: noise x_start to level t, then let the model predict the added noise
    x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
    predicted_noise = denoise_model(x_noisy, t)
    ################################################

    if loss_type == 'l1':
        loss = F.l1_loss(noise, predicted_noise)
    elif loss_type == 'l2':
        loss = F.mse_loss(noise, predicted_noise)
    elif loss_type == "huber":
        loss = F.smooth_l1_loss(noise, predicted_noise)
    else:
        raise NotImplementedError()

    return loss
```
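The three loss options differ in how they penalize large errors. On a toy example with one small and one large error (values chosen for illustration), PyTorch's `smooth_l1_loss` (Huber, default `beta=1.0`) is quadratic below 1 and linear above it, so it grows more slowly than the L2 loss for outliers.

```python
import torch
import torch.nn.functional as F

noise = torch.zeros(2)
predicted_noise = torch.tensor([0.5, 2.0])  # one small and one large error

l1 = F.l1_loss(noise, predicted_noise)             # mean |d|                         -> (0.5 + 2.0) / 2
l2 = F.mse_loss(noise, predicted_noise)            # mean d^2                         -> (0.25 + 4.0) / 2
huber = F.smooth_l1_loss(noise, predicted_noise)   # 0.5 d^2 if |d| < 1 else |d| - 0.5 -> (0.125 + 1.5) / 2
```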
Here we set up a regular PyTorch data pipeline. The dataset simply consists of images from a real dataset, like Fashion-MNIST, CIFAR-10 or ImageNet, scaled linearly to \([-1, 1]\).
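The scaling is a simple affine map. The sketch below shows it together with an inverse that one might apply before plotting samples; both helper names are hypothetical, and the clamp is an assumption to handle samples that slightly overshoot \([-1, 1]\).

```python
import torch

def to_model_range(x01: torch.Tensor) -> torch.Tensor:
    """Map pixel values from [0, 1] (as produced by ToTensor) to [-1, 1]."""
    return x01 * 2 - 1

def to_display_range(x: torch.Tensor) -> torch.Tensor:
    """Map values from [-1, 1] back to [0, 1], clamping anything outside that range."""
    return ((x + 1) / 2).clamp(0.0, 1.0)

pixels = torch.tensor([0.0, 0.25, 0.5, 1.0])
scaled = to_model_range(pixels)
restored = to_display_range(scaled)
```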
Here we use the 🤗 Datasets library to easily load the Fashion MNIST dataset from the hub. This dataset consists of images which already have the same resolution, namely 28x28.
```python
from datasets import load_dataset

# load dataset from the hub
dataset = load_dataset("fashion_mnist")
image_size = 28
channels = 1
batch_size = 128
```
Next, we define a function which we'll apply on-the-fly on the entire dataset, using the `with_transform` functionality of 🤗 Datasets. The function just applies some basic image preprocessing: random horizontal flips, conversion to tensors and finally rescaling so that the values lie in the \([-1, 1]\) range.
```python
from torchvision import transforms
from torch.utils.data import DataLoader

# define image transformations (e.g. using torchvision)
transform = Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Lambda(lambda t: (t * 2) - 1)
])

# define function (named so that it does not shadow the torchvision `transforms` module)
def transforms_fn(examples):
    examples["pixel_values"] = [transform(image.convert("L")) for image in examples["image"]]
    del examples["image"]
    return examples

transformed_dataset = dataset.with_transform(transforms_fn).remove_columns("label")

# create dataloader
dataloader = DataLoader(transformed_dataset["train"], batch_size=batch_size, shuffle=True)

batch = next(iter(dataloader))
print(batch.keys())
```
dict_keys(['pixel_values'])
Sampling is summarized in the paper as Algorithm 2:
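For reference, Algorithm 2 of the DDPM paper, written in the notation used above:

$$
\begin{aligned}
&\mathbf{x}_T \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) \\
&\textbf{for } t = T, \dots, 1: \\
&\quad \mathbf{z} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) \text{ if } t > 1, \text{ else } \mathbf{z} = \mathbf{0} \\
&\quad \mathbf{x}_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( \mathbf{x}_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \mathbf{\epsilon}_\theta(\mathbf{x}_t, t) \right) + \sigma_t \mathbf{z} \\
&\textbf{return } \mathbf{x}_0
\end{aligned}
$$

Since \(1 - \alpha_t = \beta_t\), the first term is exactly \(\mathbf{\mu}_\theta(\mathbf{x}_t, t)\). The code in this notebook uses \(\sigma_t^2 = \tilde{\beta}_t = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \beta_t\) (the `posterior_variance` tensor), and its loop index runs from \(T-1\) down to \(0\), so `t_index == 0` corresponds to \(t = 1\), the step where no noise is added.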
Generating new images from a diffusion model happens by reversing the diffusion process: we start from \(t = T\), where we sample pure noise from a Gaussian distribution, and then use our neural network to gradually denoise it (using the conditional probability it has learned), until we end up at time step \(t = 0\).
As shown above, we can derive a slightly less denoised image \(\mathbf{x}_{t-1}\) by plugging in the reparametrization of the mean, using our noise predictor (remember that the variance is known ahead of time).
Ideally, we end up with an image that looks like it came from the real data distribution.
The code below implements this.
```python
@torch.no_grad()
def p_sample(model, x, t, t_index):
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)

    ##################### Quiz #####################
    # Equation 11 in the paper
    # Use our model (noise predictor) to predict the mean
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )
    ################################################

    if t_index == 0:
        return model_mean
    else:
        posterior_variance_t = extract(posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        ##################### Quiz #####################
        # Hint: Algorithm 2 line 4:
        # x_{t-1} = mean + sigma_t * z (the scaled noise is added to the mean)
        return model_mean + torch.sqrt(posterior_variance_t) * noise
        ################################################
```
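A self-contained sanity check of the sampling step, under a strong simplifying assumption: the dataset consists of a single image \(\mathbf{x}^\star\), so the ideal noise predictor is known in closed form, \(\mathbf{\epsilon}^\star(\mathbf{x}_t, t) = (\mathbf{x}_t - \sqrt{\bar{\alpha}_t}\,\mathbf{x}^\star)/\sqrt{1 - \bar{\alpha}_t}\). `oracle_model` and `x_star` are hypothetical stand-ins for a trained network; the schedule and `p_sample` follow the code above (in float64), with the scaled noise added to the mean. With this oracle, the last step returns \(\mathbf{x}^\star\) exactly, whatever noise the loop started from.

```python
import torch
import torch.nn.functional as F

# Schedule as defined earlier in this notebook (float64 for a tight numerical check)
timesteps = 200
betas = torch.linspace(0.0001, 0.02, timesteps, dtype=torch.float64)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

def extract(a, t, x_shape):
    out = a.gather(-1, t.cpu())
    return out.reshape(t.shape[0], *((1,) * (len(x_shape) - 1))).to(t.device)

# Hypothetical single-image "dataset" and its exact noise predictor
x_star = torch.linspace(-1, 1, 16, dtype=torch.float64).reshape(1, 1, 4, 4)

def oracle_model(x, t):
    a = extract(sqrt_alphas_cumprod, t, x.shape)
    s = extract(sqrt_one_minus_alphas_cumprod, t, x.shape)
    return (x - a * x_star) / s

@torch.no_grad()
def p_sample(model, x, t, t_index):
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x.shape)
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
    model_mean = sqrt_recip_alphas_t * (x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t)
    if t_index == 0:
        return model_mean
    noise = torch.randn_like(x)
    return model_mean + torch.sqrt(extract(posterior_variance, t, x.shape)) * noise

torch.manual_seed(0)
img = torch.randn(2, 1, 4, 4, dtype=torch.float64)  # start from pure noise
for i in reversed(range(timesteps)):
    img = p_sample(oracle_model, img, torch.full((2,), i, dtype=torch.long), i)
```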
```python
# Algorithm 2 but save all images:
@torch.no_grad()
def p_sample_loop(model, shape):
    device = next(model.parameters()).device

    b = shape[0]
    # start from pure noise (for each example in the batch)
    img = torch.randn(shape, device=device)
    imgs = []

    for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):
        img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i)
        imgs.append(img.cpu().numpy())
    return imgs
```
Next, we train the model in regular PyTorch fashion.
Below, we define the model and move it to the GPU (if one is available). We also define a standard optimizer (Adam).
```python
from torch.optim import Adam
from pathlib import Path

device = "cuda" if torch.cuda.is_available() else "cpu"

model = Unet(
    dim=image_size,
    channels=channels,
    dim_mults=(1, 2, 4,)
)
model.to(device)

optimizer = Adam(model.parameters(), lr=1e-3)
```
Let's start training!
```python
from torchvision.utils import save_image

epochs = 5

for epoch in range(epochs):
    print(f'\nepoch: {epoch}')
    for step, batch in enumerate(dataloader):
        optimizer.zero_grad()

        batch_size = batch["pixel_values"].shape[0]
        batch = batch["pixel_values"].to(device)

        # Algorithm 1 line 3: sample t uniformly for every example in the batch
        t = torch.randint(0, timesteps, (batch_size,), device=device).long()

        loss = p_losses(model, batch, t, loss_type="huber")

        if step % 100 == 0:
            print(f"\tLoss: {loss.item():.4f} / step: {step:4d}")

        loss.backward()
        optimizer.step()
```
epoch: 0
Loss: 0.4535 / step: 0
Loss: 0.1220 / step: 100
Loss: 0.0736 / step: 200
Loss: 0.0643 / step: 300
Loss: 0.0643 / step: 400
epoch: 1
Loss: 0.0532 / step: 0
Loss: 0.0469 / step: 100
Loss: 0.0501 / step: 200
Loss: 0.0401 / step: 300
Loss: 0.0437 / step: 400
epoch: 2
Loss: 0.0494 / step: 0
Loss: 0.0439 / step: 100
Loss: 0.0424 / step: 200
Loss: 0.0471 / step: 300
Loss: 0.0458 / step: 400
epoch: 3
Loss: 0.0512 / step: 0
Loss: 0.0410 / step: 100
Loss: 0.0415 / step: 200
Loss: 0.0518 / step: 300
Loss: 0.0428 / step: 400
epoch: 4
Loss: 0.0463 / step: 0
Loss: 0.0493 / step: 100
Loss: 0.0428 / step: 200
Loss: 0.0430 / step: 300
Loss: 0.0460 / step: 400
To sample from the model, we can just use our sample function defined above:
```python
# sample 64 images
sampling_batch_size = 64
img_shape = (sampling_batch_size, channels, image_size, image_size)
samples = p_sample_loop(model, shape=img_shape)
```
Keep in mind that the dataset we trained on is pretty low-resolution (28x28).
We can also create a gif of the denoising process:
```python
from matplotlib import animation, rc

random_index = np.random.randint(sampling_batch_size)

fig = plt.figure()
ims = []
for i in range(timesteps):
    im = plt.imshow(samples[i][random_index].reshape(image_size, image_size, channels).squeeze(), cmap="gray", animated=True)
    ims.append([im])

animate = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=1000)
rc('animation', html='html5')
animate
```
Seems like the model is capable of generating a nice T-shirt!
Note that the DDPM paper showed that diffusion models are a promising direction for (un)conditional image generation. This has since then (immensely) been improved, most notably for text-conditional image generation. Below, we list some important (but far from exhaustive) follow-up works:
Note that this list only includes important works until the time of writing, which is June 7th, 2022.
For now, it seems that the main (perhaps only) disadvantage of diffusion models is that they require multiple forward passes to generate an image (which is not the case for generative models like GANs). However, there's research going on that enables high-fidelity generation in as few as 10 denoising steps.
Reference
- AI504: Programming for AI Lecture at KAIST AI