Diffusion Model로 MNIST 생성해보기2

Tetrapod·2024년 6월 17일

Diffusion 관련

목록 보기
2/4

이 글에서는 diffusion 모델에 class 조건을 부여하여 원하는 MNIST 글자를 생성해 보고자한다.


MNIST 데이터셋 받기

  • torchvision을 이용하여 MNIST 데이터셋을 다운받는다.
dataset = torchvision.datasets.MNIST(root="mnist/", train=True, download=True, transform=torchvision.transforms.ToTensor())
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

모델 정의

  • 모델에 class 정보를 추가로 입력받기 위해 Embedding layer가 추가되었다.
class ClassConditionedUnet(nn.Module):
  def __init__(self, num_classes=10, class_emb_size=4):
    super().__init__()

    self.class_emb = nn.Embedding(num_classes, class_emb_size)

    self.model = UNet2DModel(
        sample_size=28,
        in_channels=1 + class_emb_size,
        out_channels=1,
        layers_per_block=2,
        block_out_channels=(32, 64, 64),
        down_block_types=(
            "DownBlock2D",
            "AttnDownBlock2D",
            "AttnDownBlock2D",
        ),
        up_block_types=(
            "AttnUpBlock2D",
            "AttnUpBlock2D",
            "UpBlock2D",
          ),
    )

  def forward(self, x, t, class_labels):
    bs, ch, w, h = x.shape

    class_cond = self.class_emb(class_labels) # (bs, 4)
    class_cond = class_cond.view(bs, class_cond.shape[1], 1, 1).expand(bs, class_cond.shape[1], w, h)

    net_input = torch.cat((x, class_cond), 1) # (bs, 5, 28, 28)
    return self.model(net_input, t).sample # (bs, 1, 28, 28)
  • 아래 코드를 확인하면 클래스의 임베딩이 이미지 x의 channel에 이어붙는 것을 알 수 있다.
  • 그래서 UNet2DModel의 in_channels에도 클래스 임베딩 사이즈만큼 더해서 들어간다.
class_cond = self.class_emb(class_labels) # (bs, 4)
class_cond = class_cond.view(bs, class_cond.shape[1], 1, 1).expand(bs, class_cond.shape[1], w, h)
net_input = torch.cat((x, class_cond), 1) # (bs, 5, 28, 28)

학습 프로세스

  • 학습 프로세스는 아래 코드와 같이 y가 추가되는 것 이외는 전부 같다.
pred = net(noisy_x, timesteps, y)

Diffusion 프로세스

  • 각 클래스 당 8개씩 총 80개의 이미지를 생성해 보고자한다.
  • 마찬가지로 모델에 y가 추가로 입력된다.
x = torch.randn(80, 1, 28, 28).to(device)
y = torch.tensor([[i]*8 for i in range(10)]).flatten().to(device)

for i, t in tqdm(enumerate(noise_scheduler.timesteps)):
    with torch.no_grad():
        residual = net(x, t, y)
    x = noise_scheduler.step(residual, t, x).prev_sample

fig, ax = plt.subplots(1, 1, figsize=(12, 12))
ax.imshow(torchvision.utils.make_grid(x.detach().cpu().clip(-1, 1), nrow=8)[0], cmap='Greys')
  • 아래는 timesteps로 1000을 준 Diffusion 프로세스 결과이다.


Reference

0개의 댓글