이 글에서는 diffusion 모델에 class 조건을 부여하여 원하는 MNIST 숫자를 생성해 보고자 한다.
# MNIST training split rooted at "mnist/" (download requested); every image is
# passed through ToTensor before being returned.
dataset = torchvision.datasets.MNIST(
    root="mnist/",
    train=True,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)

# Shuffled mini-batches of 8 samples each.
train_dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=8)
class ClassConditionedUnet(nn.Module):
    """UNet2DModel wrapper that conditions the network on a class label.

    Each label is looked up in an embedding table and the resulting values are
    appended to the input image as extra, spatially constant channels.
    """

    def __init__(self, num_classes=10, class_emb_size=4):
        """Create the label embedding and the wrapped UNet.

        Args:
            num_classes: Number of rows in the label embedding table.
            class_emb_size: Length of each label embedding; also the number of
                extra input channels handed to the UNet.
        """
        super().__init__()

        # Lookup table: label index -> vector of length class_emb_size.
        self.class_emb = nn.Embedding(num_classes, class_emb_size)

        # The UNet takes 1 image channel plus one channel per embedding value
        # and produces a single output channel.
        self.model = UNet2DModel(
            sample_size=28,
            in_channels=1 + class_emb_size,
            out_channels=1,
            layers_per_block=2,
            block_out_channels=(32, 64, 64),
            down_block_types=(
                "DownBlock2D",
                "AttnDownBlock2D",
                "AttnDownBlock2D",
            ),
            up_block_types=(
                "AttnUpBlock2D",
                "AttnUpBlock2D",
                "UpBlock2D",
            ),
        )

    def forward(self, x, t, class_labels):
        """Run the UNet on ``x`` with the class embedding appended as channels.

        Args:
            x: 4-D tensor laid out as (batch, channels, height, width). The
                wrapped UNet is configured for 1 image channel.
            t: Timestep(s), passed to the wrapped UNet unchanged.
            class_labels: Label indices, one per batch element, used to index
                the embedding table.

        Returns:
            The ``.sample`` attribute of the wrapped UNet's output
            (presumably (batch, 1, height, width) given ``out_channels=1``).
        """
        # Dim 2 is height and dim 3 is width; the channel count is not needed.
        bs, _, height, width = x.shape

        class_cond = self.class_emb(class_labels)  # (bs, class_emb_size)
        emb_size = class_cond.shape[1]

        # Broadcast every embedding value over the whole spatial grid so it can
        # be stacked onto the image as additional channels.
        class_cond = class_cond.view(bs, emb_size, 1, 1).expand(bs, emb_size, height, width)

        # Channel-wise concat: (bs, channels + emb_size, height, width).
        net_input = torch.cat((x, class_cond), 1)

        return self.model(net_input, t).sample
# NOTE(review): the next three lines repeat, verbatim, the conditioning steps of
# ClassConditionedUnet.forward. `self`, `bs`, `w`, `h`, `x` and `class_labels`
# are not defined at this point in the visible source, so this reads as an
# illustrative excerpt rather than runnable code -- confirm against the full file.
# Look up one embedding vector per label.
class_cond = self.class_emb(class_labels) # (bs, 4)
# Reshape to (bs, emb, 1, 1) and expand over the two spatial dimensions.
class_cond = class_cond.view(bs, class_cond.shape[1], 1, 1).expand(bs, class_cond.shape[1], w, h)
# Stack the expanded embedding onto the image along the channel dimension.
net_input = torch.cat((x, class_cond), 1) # (bs, 5, 28, 28)
# Call the network with the labels `y` as a third argument.
# NOTE(review): `net`, `noisy_x`, `timesteps` and `y` are not defined before this
# line in the visible source -- presumably taken from a training loop; confirm.
pred = net(noisy_x, timesteps, y)
# Generate 80 images: 8 samples for each of the 10 class labels.
# Start from standard-normal noise, one 1x28x28 image per sample.
x = torch.randn(80, 1, 28, 28).to(device)
# Labels ordered as eight 0s, eight 1s, ..., eight 9s, so with nrow=8 below every
# row of the grid holds samples that share one label.
y = torch.tensor([[i] * 8 for i in range(10)]).flatten().to(device)

# Iterate the timesteps directly instead of through enumerate(): no loop index
# is needed, and tqdm can then read the length (when timesteps has one) and
# show a total / ETA.
for t in tqdm(noise_scheduler.timesteps):
    # Inference only, so skip building an autograd graph for the prediction.
    with torch.no_grad():
        residual = net(x, t, y)  # labels are passed along with sample and timestep

    # Scheduler update; its prev_sample becomes the input of the next iteration.
    x = noise_scheduler.step(residual, t, x).prev_sample

# Plot the samples as one image grid, values clipped to [-1, 1].
fig, ax = plt.subplots(1, 1, figsize=(12, 12))
ax.imshow(torchvision.utils.make_grid(x.detach().cpu().clip(-1, 1), nrow=8)[0], cmap='Greys')
