Paper : Barbershop: GAN-based Image Compositing using Segmentation Masks
align_face.py
: downloads the dlib shape predictor, aligns the face(s) in each unprocessed image, and saves the aligned face crops
cache_dir = Path(args.cache_dir)
cache_dir.mkdir(parents=True, exist_ok=True)

output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)

print("Downloading Shape Predictor")
f = open_url("https://drive.google.com/uc?id=1huhv8PYpNNKbGCLOaYUjOgR1pY5pmbJx", cache_dir=cache_dir, return_path=True)
predictor = dlib.shape_predictor(f)

for im in Path(args.unprocessed_dir).glob("*.*"):
    faces = align_face(str(im), predictor)

    for i, face in enumerate(faces):
        if args.output_size:
            factor = 1024 // args.output_size
            assert args.output_size * factor == 1024
            face_tensor = torchvision.transforms.ToTensor()(face).unsqueeze(0).cuda()
            face_tensor_lr = face_tensor[0].cpu().detach().clamp(0, 1)
            face = torchvision.transforms.ToPILImage()(face_tensor_lr)
            if factor != 1:
                face = face.resize((args.output_size, args.output_size), PIL.Image.LANCZOS)

        if len(faces) > 1:
            face.save(Path(args.output_dir) / (im.stem + f"_{i}.png"))
        else:
            face.save(Path(args.output_dir) / (im.stem + f".png"))
align_face
: the face-alignment function. It locates facial landmarks and re-crops/warps the image so that the face is centered. Given an image file path (filepath) and a landmark model (predictor), it returns the aligned face image(s).
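The function itself is not reproduced in this note. Below is a minimal sketch of the idea, assuming a simplified crop-and-resize around the landmark centroid rather than the repository's exact FFHQ-style quad warp:

import dlib
import numpy as np
import PIL.Image

def align_face_sketch(filepath, predictor, output_size=1024):
    """Illustrative only: detect a face, read its 68 landmarks, and crop a square
    around the landmark centroid. The real align_face additionally rotates/warps
    the crop so the eyes and mouth land at canonical FFHQ positions."""
    detector = dlib.get_frontal_face_detector()
    img = np.array(PIL.Image.open(filepath).convert('RGB'))

    faces = []
    for det in detector(img, 1):
        shape = predictor(img, det)
        pts = np.array([[p.x, p.y] for p in shape.parts()])  # 68 x 2 landmark coordinates
        center = pts.mean(axis=0)
        radius = 0.9 * (pts.max(axis=0) - pts.min(axis=0)).max()
        box = tuple(int(v) for v in (center[0] - radius, center[1] - radius,
                                     center[0] + radius, center[1] + radius))
        crop = PIL.Image.fromarray(img).crop(box).resize(
            (output_size, output_size), PIL.Image.LANCZOS)
        faces.append(crop)
    return faces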
main.py
: ① Embedding ② Alignment ③ Blending
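For reference, the pipeline is usually launched with something like: python main.py --im_path1 identity.png --im_path2 structure.png --im_path3 appearance.png --sign realistic --smooth 5 (the exact flag names depend on the repository's argparse setup; im_path1 supplies the face identity, im_path2 the hair structure, im_path3 the hair appearance/color).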
from models.Embedding import Embedding
from models.Alignment import Alignment
from models.Blending import Blending
def main(args):
    # ① Embedding
    ii2s = Embedding(args)

    im_path1 = os.path.join(args.input_dir, args.im_path1)
    im_path2 = os.path.join(args.input_dir, args.im_path2)
    im_path3 = os.path.join(args.input_dir, args.im_path3)
    im_set = {im_path1, im_path2, im_path3}
    ii2s.invert_images_in_W([*im_set])
    ii2s.invert_images_in_FS([*im_set])

    # ② Alignment
    align = Alignment(args)
    align.align_images(im_path1, im_path2, sign=args.sign, align_more_region=False, smooth=args.smooth)
    if im_path2 != im_path3:
        align.align_images(im_path1, im_path3, sign=args.sign, align_more_region=False, smooth=args.smooth, save_intermediate=False)

    # ③ Blending
    blend = Blending(args)
    blend.blend_images(im_path1, im_path2, im_path3, sign=args.sign)
Embedding.py
ii2s = Embedding(args)
class Embedding(nn.Module):
    def __init__(self, opts):
        super(Embedding, self).__init__()
        self.opts = opts
        self.net = Net(self.opts)
        self.load_downsampling()
        self.setup_embedding_loss_builder()

    def load_downsampling(self):
        factor = self.opts.size // 256
        self.downsample = BicubicDownSample(factor=factor)

    def setup_embedding_loss_builder(self):
        self.loss_builder = EmbeddingLossBuilder(self.opts)
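EmbeddingLossBuilder is not shown in this note. Judging from the loss_dic keys used below ('l2', 'percep', 'p-norm'), it combines a pixel L2 term, a perceptual term, and a regularizer on the latent code. A rough stand-in sketch, with the weights and the exact P-norm prior assumed:

import torch
import lpips  # assumed perceptual metric; the repository uses its own perceptual loss

class EmbeddingLossSketch(torch.nn.Module):
    """Illustrative stand-in for EmbeddingLossBuilder, not the repository's implementation."""
    def __init__(self, w_l2=1.0, w_percep=1.0, w_pnorm=0.001):
        super().__init__()
        self.percep = lpips.LPIPS(net='vgg').eval()
        self.w_l2, self.w_percep, self.w_pnorm = w_l2, w_percep, w_pnorm

    def forward(self, ref_im_H, gen_im_H, ref_im_L, gen_im_L, latent_in):
        l2 = torch.nn.functional.mse_loss(gen_im_H, ref_im_H)
        percep = self.percep(gen_im_L, ref_im_L).mean()   # perceptual distance on downsampled images
        p_norm = latent_in.pow(2).mean()                  # crude stand-in for the P-norm latent prior
        loss = self.w_l2 * l2 + self.w_percep * percep + self.w_pnorm * p_norm
        return loss, {'l2': l2, 'percep': percep, 'p-norm': p_norm}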
ii2s.invert_images_in_W([*im_set])
def invert_images_in_W(self, image_path=None):
    self.setup_dataloader(image_path=image_path)
    device = self.opts.device
    ibar = tqdm(self.dataloader, desc='Images')
    for ref_im_H, ref_im_L, ref_name in ibar:
        optimizer_W, latent = self.setup_W_optimizer()
        pbar = tqdm(range(self.opts.W_steps), desc='Embedding', leave=False)
        for step in pbar:
            optimizer_W.zero_grad()
            latent_in = torch.stack(latent).unsqueeze(0)
            gen_im, _ = self.net.generator([latent_in], input_is_latent=True, return_latents=False)
            im_dict = {
                'ref_im_H': ref_im_H.to(device),
                'ref_im_L': ref_im_L.to(device),
                'gen_im_H': gen_im,
                'gen_im_L': self.downsample(gen_im)
            }
            loss, loss_dic = self.cal_loss(im_dict, latent_in)
            loss.backward()
            optimizer_W.step()

            if self.opts.verbose:
                pbar.set_description('Embedding: Loss: {:.3f}, L2 loss: {:.3f}, Perceptual loss: {:.3f}, P-norm loss: {:.3f}'
                                     .format(loss, loss_dic['l2'], loss_dic['percep'], loss_dic['p-norm']))

            if self.opts.save_intermediate and step % self.opts.save_interval == 0:
                self.save_W_intermediate_results(ref_name, gen_im, latent_in, step)

        self.save_W_results(ref_name, gen_im, latent_in)
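setup_W_optimizer is not reproduced here. Since latent is later stacked into a (1, 18, 512) W+ code, a plausible sketch is one trainable 512-d vector per generator layer, initialized from the mean latent and optimized with Adam (the initialization details and opts.learning_rate are assumptions):

import torch

def setup_W_optimizer_sketch(net, opts, n_layers=18, latent_dim=512):
    """Illustrative only: one requires_grad latent per generator layer,
    so torch.stack(latent).unsqueeze(0) yields a (1, 18, 512) W+ code."""
    with torch.no_grad():
        mean_w = net.generator.mean_latent(10000) if hasattr(net.generator, 'mean_latent') \
            else torch.zeros(1, latent_dim, device=opts.device)
    latent = []
    for _ in range(n_layers):
        w = mean_w.detach().clone().reshape(latent_dim).to(opts.device)
        w.requires_grad = True
        latent.append(w)
    optimizer_W = torch.optim.Adam(latent, lr=opts.learning_rate)
    return optimizer_W, latent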
ii2s.invert_images_in_FS([*im_set])
def invert_images_in_FS(self, image_path=None):
    self.setup_dataloader(image_path=image_path)
    output_dir = self.opts.output_dir
    device = self.opts.device
    ibar = tqdm(self.dataloader, desc='Images')
    for ref_im_H, ref_im_L, ref_name in ibar:
        latent_W_path = os.path.join(output_dir, 'W+', f'{ref_name[0]}.npy')
        latent_W = torch.from_numpy(convert_npy_code(np.load(latent_W_path))).to(device)
        F_init, _ = self.net.generator([latent_W], input_is_latent=True, return_latents=False, start_layer=0, end_layer=3)
        optimizer_FS, latent_F, latent_S = self.setup_FS_optimizer(latent_W, F_init)

        pbar = tqdm(range(self.opts.FS_steps), desc='Embedding', leave=False)
        for step in pbar:
            optimizer_FS.zero_grad()
            latent_in = torch.stack(latent_S).unsqueeze(0)
            gen_im, _ = self.net.generator([latent_in], input_is_latent=True, return_latents=False,
                                           start_layer=4, end_layer=8, layer_in=latent_F)
            im_dict = {
                'ref_im_H': ref_im_H.to(device),
                'ref_im_L': ref_im_L.to(device),
                'gen_im_H': gen_im,
                'gen_im_L': self.downsample(gen_im)
            }
            loss, loss_dic = self.cal_loss(im_dict, latent_in)
            loss.backward()
            optimizer_FS.step()

            if self.opts.verbose:
                pbar.set_description(
                    'Embedding: Loss: {:.3f}, L2 loss: {:.3f}, Perceptual loss: {:.3f}, P-norm loss: {:.3f}, L_F loss: {:.3f}'
                    .format(loss, loss_dic['l2'], loss_dic['percep'], loss_dic['p-norm'], loss_dic['l_F']))

        self.save_FS_results(ref_name, gen_im, latent_in, latent_F)
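save_FS_results and the load_FS_latent used later are not shown. The FS inversion apparently stores, per image, the optimized style code S (1 x 18 x 512) and the feature map F from generator layers 0-3 in a single .npz under '{output_dir}/FS/'. A minimal sketch, with the key names ('latent_in', 'latent_F') and the 32 x 32 feature resolution assumed:

import os
import numpy as np
import torch

def save_FS_sketch(output_dir, im_name, latent_in, latent_F):
    # hypothetical writer matching the '{output_dir}/FS/{name}.npz' paths used elsewhere
    os.makedirs(os.path.join(output_dir, 'FS'), exist_ok=True)
    np.savez(os.path.join(output_dir, 'FS', f'{im_name}.npz'),
             latent_in=latent_in.detach().cpu().numpy(),
             latent_F=latent_F.detach().cpu().numpy())

def load_FS_latent_sketch(path, device):
    data = np.load(path)
    latent_in = torch.from_numpy(data['latent_in']).to(device)  # (1, 18, 512) style code
    latent_F = torch.from_numpy(data['latent_F']).to(device)    # (1, C, 32, 32) feature map
    return latent_in, latent_F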
Alignment.py
align = Alignment(args)
class Alignment(nn.Module):
    def __init__(self, opts, net=None):
        super(Alignment, self).__init__()
        self.opts = opts
        if not net:
            self.net = Net(self.opts)
        else:
            self.net = net
        self.load_segmentation_network()
        self.load_downsampling()
        self.setup_align_loss_builder()

    def load_segmentation_network(self):
        self.seg = BiSeNet(n_classes=16)
        self.seg.to(self.opts.device)
        if not os.path.exists(self.opts.seg_ckpt):
            download_weight(self.opts.seg_ckpt)
        self.seg.load_state_dict(torch.load(self.opts.seg_ckpt))
        for param in self.seg.parameters():
            param.requires_grad = False
        self.seg.eval()

    def load_downsampling(self):
        self.downsample = BicubicDownSample(factor=self.opts.size // 512)
        self.downsample_256 = BicubicDownSample(factor=self.opts.size // 256)

    def setup_align_loss_builder(self):
        self.loss_builder = AlignLossBuilder(self.opts)
align.align_images(im_path1, im_path2, sign=args.sign, align_more_region=False, smooth=args.smooth)
: aligns the two images by optimizing their latent codes toward a common target segmentation mask (img_path1 provides the identity, img_path2 the hair structure)
def align_images(self, img_path1, img_path2, sign='realistic', align_more_region=False, smooth=5,
                 save_intermediate=True):
    ################## img_path1: Identity Image
    ################## img_path2: Structure Image
    device = self.opts.device
    output_dir = self.opts.output_dir

    target_mask, hair_mask_target, hair_mask1, hair_mask2 = \
        self.create_target_segmentation_mask(img_path1=img_path1, img_path2=img_path2, sign=sign,
                                             save_intermediate=save_intermediate)

    im_name_1 = os.path.splitext(os.path.basename(img_path1))[0]
    im_name_2 = os.path.splitext(os.path.basename(img_path2))[0]

    latent_FS_path_1 = os.path.join(output_dir, 'FS', f'{im_name_1}.npz')
    latent_FS_path_2 = os.path.join(output_dir, 'FS', f'{im_name_2}.npz')

    latent_1, latent_F_1 = load_FS_latent(latent_FS_path_1, device)
    latent_2, latent_F_2 = load_FS_latent(latent_FS_path_2, device)

    latent_W_path_1 = os.path.join(output_dir, 'W+', f'{im_name_1}.npy')
    latent_W_path_2 = os.path.join(output_dir, 'W+', f'{im_name_2}.npy')
save_intermediate
: whether intermediate results are written to disk
create_target_segmentation_mask
: builds the target segmentation mask that specifies which region of the output should come from which input image
load_FS_latent
: loads the saved FS latent codes (the W+ style code and the feature map F) of an image
    optimizer_align, latent_align_1 = self.setup_align_optimizer(latent_W_path_1)

    pbar = tqdm(range(self.opts.align_steps1), desc='Align Step 1', leave=False)
    for step in pbar:
        optimizer_align.zero_grad()
        latent_in = torch.cat([latent_align_1[:, :6, :], latent_1[:, 6:, :]], dim=1)
        down_seg, _ = self.create_down_seg(latent_in)

        loss_dict = {}
        ##### Cross Entropy Loss
        ce_loss = self.loss_builder.cross_entropy_loss(down_seg, target_mask)
        loss_dict["ce_loss"] = ce_loss.item()
        loss = ce_loss

        loss.backward()
        optimizer_align.step()

    intermediate_align, _ = self.net.generator([latent_in], input_is_latent=True, return_latents=False,
                                               start_layer=0, end_layer=3)
    intermediate_align = intermediate_align.clone().detach()
setup_align_optimizer
: creates the optimizer used for alignment together with a trainable copy of the image's W+ latent code
create_down_seg
: generates an image from the given latent code and runs the segmentation network on it to obtain a (downsampled) segmentation map
    optimizer_align, latent_align_2 = self.setup_align_optimizer(latent_W_path_2)

    with torch.no_grad():
        tmp_latent_in = torch.cat([latent_align_2[:, :6, :], latent_2[:, 6:, :]], dim=1)
        down_seg_tmp, I_Structure_Style_changed = self.create_down_seg(tmp_latent_in)

        current_mask_tmp = torch.argmax(down_seg_tmp, dim=1).long()
        HM_Structure = torch.where(current_mask_tmp == 10, torch.ones_like(current_mask_tmp),
                                   torch.zeros_like(current_mask_tmp))
        HM_Structure = F.interpolate(HM_Structure.float().unsqueeze(0), size=(256, 256), mode='nearest')

    pbar = tqdm(range(self.opts.align_steps2), desc='Align Step 2', leave=False)
    for step in pbar:
        optimizer_align.zero_grad()
        latent_in = torch.cat([latent_align_2[:, :6, :], latent_2[:, 6:, :]], dim=1)
        down_seg, gen_im = self.create_down_seg(latent_in)

        Current_Mask = torch.argmax(down_seg, dim=1).long()
        HM_G_512 = torch.where(Current_Mask == 10, torch.ones_like(Current_Mask),
                               torch.zeros_like(Current_Mask)).float().unsqueeze(0)
        HM_G = F.interpolate(HM_G_512, size=(256, 256), mode='nearest')

        loss_dict = {}

        ########## Segmentation Loss
        ce_loss = self.loss_builder.cross_entropy_loss(down_seg, target_mask)
        loss_dict["ce_loss"] = ce_loss.item()
        loss = ce_loss

        #### Style Loss
        H1_region = self.downsample_256(I_Structure_Style_changed) * HM_Structure
        H2_region = self.downsample_256(gen_im) * HM_G
        style_loss = self.loss_builder.style_loss(H1_region, H2_region, mask1=HM_Structure, mask2=HM_G)
        loss_dict["style_loss"] = style_loss.item()
        loss += style_loss

        loss.backward()
        optimizer_align.step()

    latent_F_out_new, _ = self.net.generator([latent_in], input_is_latent=True, return_latents=False,
                                             start_layer=0, end_layer=3)
    latent_F_out_new = latent_F_out_new.clone().detach()
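create_down_seg itself is not listed in this note. From how it is called above and from the analogous inline code in blend_images below, it presumably generates the image at the current latent, normalizes a downsampled copy with the segmentation statistics, and runs BiSeNet on it. A rough sketch (the exact generator arguments are assumed; seg_mean/seg_std are the constants used in Blending.py):

def create_down_seg_sketch(self, latent_in):
    """Illustrative version: returns (segmentation logits, generated image)."""
    gen_im, _ = self.net.generator([latent_in], input_is_latent=True, return_latents=False)
    gen_im_0_1 = (gen_im + 1) / 2                              # StyleGAN output [-1, 1] -> [0, 1]
    im = (self.downsample(gen_im_0_1) - seg_mean) / seg_std    # 512x512 input, normalized for BiSeNet
    down_seg, _, _ = self.seg(im)                              # BiSeNet returns (main, aux1, aux2) outputs
    return down_seg, gen_im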
: builds masks for specific regions (of the original and the target image) and mixes the feature codes of the two images. This masked mixing is applied in stages, and the final mixed feature code is fed to the generator to produce and save the aligned image.
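Each mixing step below uses the same masked interpolation; writing it out makes the roles of the two feature maps explicit:

import torch

m = torch.rand(1, 1, 32, 32)            # soft mask in [0, 1] (e.g. interpolation_low)
F_a = torch.randn(1, 512, 32, 32)        # e.g. intermediate_align
F_b = torch.randn(1, 512, 32, 32)        # e.g. latent_F_1
mixed = F_a + (1 - m) * (F_b - F_a)      # identical to m * F_a + (1 - m) * F_b
assert torch.allclose(mixed, m * F_a + (1 - m) * F_b, atol=1e-6)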
    free_mask = 1 - (1 - hair_mask1.unsqueeze(0)) * (1 - hair_mask_target)

    ##############################
    free_mask, _ = self.dilate_erosion(free_mask, device, dilate_erosion=smooth)
    ##############################

    free_mask_down_32 = F.interpolate(free_mask.float(), size=(32, 32), mode='bicubic')[0]
    interpolation_low = 1 - free_mask_down_32

    latent_F_mixed = intermediate_align + interpolation_low.unsqueeze(0) * (
            latent_F_1 - intermediate_align)

    if not align_more_region:
        free_mask = hair_mask_target
        ##########################
        _, free_mask = self.dilate_erosion(free_mask, device, dilate_erosion=smooth)
        ##########################
        free_mask_down_32 = F.interpolate(free_mask.float(), size=(32, 32), mode='bicubic')[0]
        interpolation_low = 1 - free_mask_down_32

    latent_F_mixed = latent_F_out_new + interpolation_low.unsqueeze(0) * (
            latent_F_mixed - latent_F_out_new)

    free_mask = F.interpolate((hair_mask2.unsqueeze(0) * hair_mask_target).float(), size=(256, 256), mode='nearest').cuda()
    ##########################
    _, free_mask = self.dilate_erosion(free_mask, device, dilate_erosion=smooth)
    ##########################
    free_mask_down_32 = F.interpolate(free_mask.float(), size=(32, 32), mode='bicubic')[0]
    interpolation_low = 1 - free_mask_down_32

    latent_F_mixed = latent_F_2 + interpolation_low.unsqueeze(0) * (
            latent_F_mixed - latent_F_2)

    gen_im, _ = self.net.generator([latent_1], input_is_latent=True, return_latents=False, start_layer=4,
                                   end_layer=8, layer_in=latent_F_mixed)
    self.save_align_results(im_name_1, im_name_2, sign, gen_im, latent_1, latent_F_mixed,
                            save_intermediate=save_intermediate)
dilate_erosion
: adjusts the extent of a mask via dilation/erosion and smooths its boundary
save_align_results
: generates the final aligned image from the mixed features and saves it together with the latent codes
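dilate_erosion is defined elsewhere in the repository; a minimal torch-only sketch of the idea, assuming the kernel size is tied to the smooth parameter, that returns a dilated and an eroded copy of a binary hair mask:

import torch
import torch.nn.functional as F

def dilate_erosion_sketch(mask, device, dilate_erosion=5):
    """Illustrative only: max-pooling dilates a binary mask, and dilating
    the complement erodes it. Returns (dilated, eroded)."""
    mask = mask.float().to(device)          # expected shape (N, 1, H, W)
    k = 2 * dilate_erosion + 1
    dilated = F.max_pool2d(mask, kernel_size=k, stride=1, padding=dilate_erosion)
    eroded = 1 - F.max_pool2d(1 - mask, kernel_size=k, stride=1, padding=dilate_erosion)
    return dilated, eroded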
Blending.py
blend = Blending(args)
class Blending(nn.Module):
    def __init__(self, opts, net=None):
        super(Blending, self).__init__()
        self.opts = opts
        if not net:
            self.net = Net(self.opts)
        else:
            self.net = net
        self.load_segmentation_network()
        self.load_downsampling()
        self.setup_blend_loss_builder()

    def load_segmentation_network(self):
        self.seg = BiSeNet(n_classes=16)
        self.seg.to(self.opts.device)
        if not os.path.exists(self.opts.seg_ckpt):
            download_weight(self.opts.seg_ckpt)
        self.seg.load_state_dict(torch.load(self.opts.seg_ckpt))
        for param in self.seg.parameters():
            param.requires_grad = False
        self.seg.eval()

    def load_downsampling(self):
        self.downsample = BicubicDownSample(factor=self.opts.size // 512)
        self.downsample_256 = BicubicDownSample(factor=self.opts.size // 256)

    def setup_blend_loss_builder(self):
        self.loss_builder = BlendLossBuilder(self.opts)
blend.blend_images(im_path1, im_path2, im_path3, sign=args.sign)
: the blending function: composites the features of the first two images and adds in the (hair appearance) features of the third image
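setup_blend_optimizer is not reproduced here. Given that interpolation_latent.unsqueeze(0) is broadcast against (1, 18, 512) style codes below, a plausible sketch is a single trainable 18 x 512 interpolation weight, started at zero and optimized with Adam (the shape, the missing [0, 1] clamping, and opts.learning_rate are assumptions):

import torch

def setup_blend_optimizer_sketch(opts):
    # Illustrative only: per-layer, per-channel blending weights between latent_1 and latent_3.
    interpolation_latent = torch.zeros((18, 512), requires_grad=True, device=opts.device)
    opt_blend = torch.optim.Adam([interpolation_latent], lr=opts.learning_rate)
    return opt_blend, interpolation_latent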
def blend_images(self, img_path1, img_path2, img_path3, sign='realistic'):
    device = self.opts.device
    output_dir = self.opts.output_dir

    im_name_1 = os.path.splitext(os.path.basename(img_path1))[0]
    im_name_2 = os.path.splitext(os.path.basename(img_path2))[0]
    im_name_3 = os.path.splitext(os.path.basename(img_path3))[0]

    I_1 = load_image(img_path1, downsample=True).to(device).unsqueeze(0)
    I_3 = load_image(img_path3, downsample=True).to(device).unsqueeze(0)

    HM_1D, _ = cuda_unsqueeze(dilate_erosion_mask_path(img_path1, self.seg), device)
    HM_3D, HM_3E = cuda_unsqueeze(dilate_erosion_mask_path(img_path3, self.seg), device)

    opt_blend, interpolation_latent = self.setup_blend_optimizer()
    latent_1, latent_F_mixed = load_FS_latent(os.path.join(output_dir, 'Align_{}'.format(sign),
                                                           '{}_{}.npz'.format(im_name_1, im_name_3)), device)
    latent_3, _ = load_FS_latent(os.path.join(output_dir, 'FS',
                                              '{}.npz'.format(im_name_3)), device)

    with torch.no_grad():
        I_X, _ = self.net.generator([latent_1], input_is_latent=True, return_latents=False, start_layer=4,
                                    end_layer=8, layer_in=latent_F_mixed)
        I_X_0_1 = (I_X + 1) / 2
        IM = (self.downsample(I_X_0_1) - seg_mean) / seg_std
        down_seg, _, _ = self.seg(IM)
        current_mask = torch.argmax(down_seg, dim=1).long().cpu().float()
        HM_X = torch.where(current_mask == 10, torch.ones_like(current_mask), torch.zeros_like(current_mask))
        HM_X = F.interpolate(HM_X.unsqueeze(0), size=(256, 256), mode='nearest').squeeze()
        HM_XD, _ = cuda_unsqueeze(dilate_erosion_mask_tensor(HM_X), device)
        target_mask = (1 - HM_1D) * (1 - HM_3D) * (1 - HM_XD)

    pbar = tqdm(range(self.opts.blend_steps), desc='Blend', leave=False)
    for step in pbar:
        opt_blend.zero_grad()

        latent_mixed = latent_1 + interpolation_latent.unsqueeze(0) * (latent_3 - latent_1)

        I_G, _ = self.net.generator([latent_mixed], input_is_latent=True, return_latents=False, start_layer=4,
                                    end_layer=8, layer_in=latent_F_mixed)
        I_G_0_1 = (I_G + 1) / 2

        im_dict = {
            'gen_im': self.downsample_256(I_G),
            'im_1': I_1,
            'im_3': I_3,
            'mask_face': target_mask,
            'mask_hair': HM_3E
        }
        loss, loss_dic = self.loss_builder(**im_dict)

        loss.backward()
        opt_blend.step()

    ############## Load F code from '{}_{}.npz'.format(im_name_1, im_name_2)
    _, latent_F_mixed = load_FS_latent(os.path.join(output_dir, 'Align_{}'.format(sign),
                                                    '{}_{}.npz'.format(im_name_1, im_name_2)), device)
    I_G, _ = self.net.generator([latent_mixed], input_is_latent=True, return_latents=False, start_layer=4,
                                end_layer=8, layer_in=latent_F_mixed)
    self.save_blend_results(im_name_1, im_name_2, im_name_3, sign, I_G, latent_mixed, latent_F_mixed)