
In this post, we apply two kinds of guidance to a pretrained diffusion model: a simple color loss and a CLIP-based text loss.
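The snippets below assume the following imports and device setup (a minimal sketch; adjust to your own environment):

import numpy as np
import torch
import torchvision
from PIL import Image
from tqdm import tqdm
from diffusers import DDPMPipeline, DDIMScheduler

device = "cuda" if torch.cuda.is_available() else "cpu"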
pipeline_name = "johnowhitaker/sd-class-wikiart-from-bedrooms"
image_pipe = DDPMPipeline.from_pretrained(pipeline_name).to(device)
scheduler = DDIMScheduler.from_pretrained(pipeline_name)
scheduler.set_timesteps(num_inference_steps=40)
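
Before adding any guidance, it helps to have an unguided baseline to compare against. A quick sketch of plain DDIM sampling with this pipeline:

# Plain (unguided) sampling loop for comparison
x = torch.randn(4, 3, 256, 256).to(device)
for t in tqdm(scheduler.timesteps):
    model_input = scheduler.scale_model_input(x, t)
    with torch.no_grad():
        noise_pred = image_pipe.unet(model_input, t)["sample"]
    x = scheduler.step(noise_pred, t, x).prev_sample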

def color_loss(images, target_color=(0.1, 0.9, 0.5)):
    # Map the target color from (0, 1) to the model's (-1, 1) range
    target = torch.tensor(target_color).to(images.device) * 2 - 1
    target = target[None, :, None, None]  # reshape to (b, c, h, w) for broadcasting
    error = torch.abs(images - target).mean()
    return error
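
As a quick sanity check (a small sketch, not part of the guidance loop itself), an image filled with the target color should give a loss near zero, while random images score higher:

# Batch filled with the target color, already in (-1, 1)
target_batch = torch.tensor((0.1, 0.9, 0.5))[None, :, None, None] * 2 - 1
target_batch = target_batch.expand(2, 3, 64, 64)
print(color_loss(target_batch))               # ~0
print(color_loss(torch.randn(2, 3, 64, 64)))  # noticeably larger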
guidance_loss_scale = 40

x = torch.randn(8, 3, 256, 256).to(device)

for i, t in tqdm(enumerate(scheduler.timesteps)):
    model_input = scheduler.scale_model_input(x, t)

    # Predict the noise residual (no gradients needed through the UNet here)
    with torch.no_grad():
        noise_pred = image_pipe.unet(model_input, t)["sample"]

    # Set requires_grad on x so we can backprop the guidance loss through the x0 prediction
    x = x.detach().requires_grad_()

    # Get the predicted denoised image x0
    x0 = scheduler.step(noise_pred, t, x).pred_original_sample

    # Apply gradients: compute the guidance loss and its gradient w.r.t. x
    loss = color_loss(x0) * guidance_loss_scale
    cond_grad = -torch.autograd.grad(loss, x)[0]

    # Nudge x in the direction that lowers the loss
    x = x.detach() + cond_grad

    # Take the usual scheduler step
    x = scheduler.step(noise_pred, t, x).prev_sample

# View the output
grid = torchvision.utils.make_grid(x, nrow=4)
im = grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5
Image.fromarray(np.array(im * 255).astype(np.uint8))
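The loop above keeps the UNet call inside torch.no_grad() and only backpropagates through the scheduler's x0 prediction, which is cheap but approximate. A heavier variant (a sketch, assuming the same hyperparameters) lets the gradient flow through the UNet as well by requiring gradients on x before the forward pass:

guidance_loss_scale = 40
x = torch.randn(4, 3, 256, 256).to(device)

for i, t in tqdm(enumerate(scheduler.timesteps)):
    # Require gradients on x *before* the UNet runs so the guidance
    # gradient also flows through the noise prediction (uses more memory)
    x = x.detach().requires_grad_()
    model_input = scheduler.scale_model_input(x, t)
    noise_pred = image_pipe.unet(model_input, t)["sample"]

    x0 = scheduler.step(noise_pred, t, x).pred_original_sample
    loss = color_loss(x0) * guidance_loss_scale
    cond_grad = -torch.autograd.grad(loss, x)[0]

    x = x.detach() + cond_grad
    x = scheduler.step(noise_pred, t, x).prev_sample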

import open_clip

clip_model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32", pretrained="openai"
)
clip_model.to(device)

# Random augmentations + CLIP normalization, applied to the image before encoding
tfms = torchvision.transforms.Compose(
    [
        torchvision.transforms.RandomResizedCrop(224),
        torchvision.transforms.RandomAffine(5),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)
def clip_loss(image, text_features):
    image_features = clip_model.encode_image(tfms(image))  # (batch, 512)
    input_normed = torch.nn.functional.normalize(image_features.unsqueeze(1), dim=2)
    embed_normed = torch.nn.functional.normalize(text_features.unsqueeze(0), dim=2)
    # Squared great-circle distance between the normalized embeddings
    dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
    return dists.mean()
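
To see why this is a squared great-circle distance, here is a quick check (a standalone sketch) on two random unit vectors: the chord length between unit vectors is 2·sin(θ/2), so dividing by 2 and taking arcsin recovers θ/2, and the expression evaluates to θ²/2.

u = torch.nn.functional.normalize(torch.randn(512), dim=0)
v = torch.nn.functional.normalize(torch.randn(512), dim=0)
theta = torch.arccos((u * v).sum())                    # angle between u and v
sq_gcd = (u - v).norm().div(2).arcsin().pow(2).mul(2)  # same formula as clip_loss
print(theta.pow(2) / 2, sq_gcd)                        # should agree up to numerical error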
prompt = "Red Rose (still life), red flower painting"
guidance_scale = 8
n_cuts = 4
scheduler.set_timesteps(50)

# Embed the prompt with CLIP
text = open_clip.tokenize([prompt]).to(device)
with torch.no_grad(), torch.cuda.amp.autocast():
    text_features = clip_model.encode_text(text)  # (1, 512)

x = torch.randn(4, 3, 256, 256).to(device)

for i, t in tqdm(enumerate(scheduler.timesteps)):
    # Forward pass through the UNet
    model_input = scheduler.scale_model_input(x, t)
    with torch.no_grad():
        noise_pred = image_pipe.unet(model_input, t)["sample"]

    # Average the guidance gradient over n_cuts randomly augmented views
    cond_grad = 0
    for cut in range(n_cuts):
        x = x.detach().requires_grad_()
        x0 = scheduler.step(noise_pred, t, x).pred_original_sample
        loss = clip_loss(x0, text_features) * guidance_scale
        cond_grad -= torch.autograd.grad(loss, x)[0] / n_cuts

    if i % 25 == 0:
        print("Step:", i, ", Guidance loss:", loss.item())

    # Apply the gradient to x, scaled by sqrt(alpha_bar)
    alpha_bar = scheduler.alphas_cumprod[i]
    x = x.detach() + cond_grad * alpha_bar.sqrt()

    # Scheduler step
    x = scheduler.step(noise_pred, t, x).prev_sample

# View the output
grid = torchvision.utils.make_grid(x.detach(), nrow=4)
im = grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5
Image.fromarray(np.array(im * 255).astype(np.uint8))
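
If you want to keep the individual results, one way (a small sketch; the filenames are placeholders) is to convert each sample back to a PIL image and save it:

for idx, img in enumerate(x.detach().cpu()):
    pil = torchvision.transforms.functional.to_pil_image(img.clip(-1, 1) * 0.5 + 0.5)
    pil.save(f"clip_guided_{idx}.png")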
