[Code Review] Barbershop: GAN-based Image Compositing using Segmentation Masks


Paper: Barbershop: GAN-based Image Compositing using Segmentation Masks

Data preprocessing

align_face.py
: downloads the dlib shape-predictor weights, aligns the face(s) in each input image, and saves the aligned face images

cache_dir = Path(args.cache_dir)
cache_dir.mkdir(parents=True, exist_ok=True)

output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)

# download the dlib landmark model used by align_face
print("Downloading Shape Predictor")
f = open_url("https://drive.google.com/uc?id=1huhv8PYpNNKbGCLOaYUjOgR1pY5pmbJx", cache_dir=cache_dir, return_path=True)
predictor = dlib.shape_predictor(f)

for im in Path(args.unprocessed_dir).glob("*.*"):
    faces = align_face(str(im), predictor)

    for i, face in enumerate(faces):
        if args.output_size:
            # align_face returns 1024x1024 crops; resize to the requested size
            factor = 1024 // args.output_size
            assert args.output_size * factor == 1024
            face_tensor = torchvision.transforms.ToTensor()(face).unsqueeze(0).cuda()
            face_tensor_lr = face_tensor[0].cpu().detach().clamp(0, 1)
            face = torchvision.transforms.ToPILImage()(face_tensor_lr)
            if factor != 1:
                face = face.resize((args.output_size, args.output_size), PIL.Image.LANCZOS)
        # suffix an index when an image contains more than one face
        if len(faces) > 1:
            face.save(Path(args.output_dir) / (im.stem + f"_{i}.png"))
        else:
            face.save(Path(args.output_dir) / (im.stem + ".png"))
  • align_face : performs the face alignment: it locates facial landmarks and re-crops the image so the face is centered. Given an image file path (filepath) and a landmark model (predictor), it returns the aligned face image(s). A minimal sketch of the landmark step follows below.
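
For reference, here is a minimal sketch of the landmark step that align_face builds on. It assumes dlib's standard detector/predictor API; the file names are placeholders, and the crop/warp math inside align_face is omitted:

import dlib
import numpy as np

detector = dlib.get_frontal_face_detector()
predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")  # placeholder path

img = dlib.load_rgb_image("photo.jpg")  # placeholder image
for det in detector(img, 1):            # upsample once to catch smaller faces
    shape = predictor(img, det)         # 68 facial landmarks
    lm = np.array([[p.x, p.y] for p in shape.parts()])
    # align_face derives its crop rectangle from landmark groups, e.g. the eye centers:
    eye_left, eye_right = lm[36:42].mean(axis=0), lm[42:48].mean(axis=0)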

Main

main.py
: ① Embedding ② Alignment ③ Blending

from models.Embedding import Embedding
from models.Alignment import Alignment
from models.Blending import Blending

def main(args):
	# ① Embedding
    ii2s = Embedding(args)

    im_path1 = os.path.join(args.input_dir, args.im_path1)
    im_path2 = os.path.join(args.input_dir, args.im_path2)
    im_path3 = os.path.join(args.input_dir, args.im_path3)

    im_set = {im_path1, im_path2, im_path3}  # a set, so duplicate paths are inverted only once
    ii2s.invert_images_in_W([*im_set])
    ii2s.invert_images_in_FS([*im_set])

	# ② Alignment
    align = Alignment(args)
    align.align_images(im_path1, im_path2, sign=args.sign, align_more_region=False, smooth=args.smooth)
    if im_path2 != im_path3:
        align.align_images(im_path1, im_path3, sign=args.sign, align_more_region=False, smooth=args.smooth, save_intermediate=False)
	
    # ③ Blending
    blend = Blending(args)
    blend.blend_images(im_path1, im_path2, im_path3, sign=args.sign)
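
Putting the pipeline together, a typical run presumably looks like the following (flag names are inferred from how args is used above; directory and file names are placeholders):

python main.py --input_dir input/face --im_path1 face.png --im_path2 shape.png --im_path3 color.png --sign realistic --smooth 5

Here im_path1 supplies the identity, im_path2 the hair structure, and im_path3 the hair appearance.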

① Embedding

Embedding.py
ii2s = Embedding(args)

class Embedding(nn.Module):

    def __init__(self, opts):
        super(Embedding, self).__init__()
        self.opts = opts
        self.net = Net(self.opts)
        self.load_downsampling()
        self.setup_embedding_loss_builder()
        

    def load_downsampling(self):
        factor = self.opts.size // 256
        self.downsample = BicubicDownSample(factor=factor)
        
    def setup_embedding_loss_builder(self):
        self.loss_builder = EmbeddingLossBuilder(self.opts)
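
BicubicDownSample shrinks the 1024×1024 generator output (factor = 1024 // 256 = 4) to the 256×256 resolution used by the low-resolution loss terms. A rough functional sketch; the repo ships its own anti-aliased module, and F.interpolate is assumed here as an approximation:

import torch.nn.functional as F

def bicubic_downsample(x, factor=4):
    # x: (N, C, H, W); returns (N, C, H // factor, W // factor)
    return F.interpolate(x, scale_factor=1.0 / factor, mode='bicubic', align_corners=False)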

ii2s.invert_images_in_W([*im_set])

    def invert_images_in_W(self, image_path=None):
        self.setup_dataloader(image_path=image_path)
        device = self.opts.device
        ibar = tqdm(self.dataloader, desc='Images')
        for ref_im_H, ref_im_L, ref_name in ibar:
            optimizer_W, latent = self.setup_W_optimizer()
            pbar = tqdm(range(self.opts.W_steps), desc='Embedding', leave=False)
            for step in pbar:
                optimizer_W.zero_grad()
                latent_in = torch.stack(latent).unsqueeze(0)

                gen_im, _ = self.net.generator([latent_in], input_is_latent=True, return_latents=False)
                im_dict = {
                    'ref_im_H': ref_im_H.to(device),
                    'ref_im_L': ref_im_L.to(device),
                    'gen_im_H': gen_im,
                    'gen_im_L': self.downsample(gen_im)
                }

                loss, loss_dic = self.cal_loss(im_dict, latent_in)
                loss.backward()
                optimizer_W.step()

                if self.opts.verbose:
                    pbar.set_description('Embedding: Loss: {:.3f}, L2 loss: {:.3f}, Perceptual loss: {:.3f}, P-norm loss: {:.3f}'
                                         .format(loss, loss_dic['l2'], loss_dic['percep'], loss_dic['p-norm']))

                if self.opts.save_intermediate and step % self.opts.save_interval == 0:
                    self.save_W_intermediate_results(ref_name, gen_im, latent_in, step)

            self.save_W_results(ref_name, gen_im, latent_in)
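
The progress bar above names the three loss terms: an L2 reconstruction loss, an LPIPS perceptual loss on the downsampled pair, and a latent regularizer. A hedged sketch of what cal_loss combines; the actual weights and the exact P-norm regularizer live in EmbeddingLossBuilder, so the values below are assumptions:

import torch.nn.functional as F

def cal_loss_sketch(im_dict, latent_in, lpips_fn, w_l2=1.0, w_percep=1.0, w_pnorm=0.001):
    l2 = F.mse_loss(im_dict['gen_im_H'], im_dict['ref_im_H'])
    percep = lpips_fn(im_dict['gen_im_L'], im_dict['ref_im_L']).mean()
    p_norm = latent_in.pow(2).mean()  # discourages codes far from the latent prior
    loss = w_l2 * l2 + w_percep * percep + w_pnorm * p_norm
    return loss, {'l2': l2, 'percep': percep, 'p-norm': p_norm}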

ii2s.invert_images_in_FS([*im_set])

    def invert_images_in_FS(self, image_path=None):
        self.setup_dataloader(image_path=image_path)
        output_dir = self.opts.output_dir
        device = self.opts.device
        ibar = tqdm(self.dataloader, desc='Images')
        for ref_im_H, ref_im_L, ref_name in ibar:

            latent_W_path = os.path.join(output_dir, 'W+', f'{ref_name[0]}.npy')
            latent_W = torch.from_numpy(convert_npy_code(np.load(latent_W_path))).to(device)
            # generator blocks 0-3 applied to the W+ code give the initial feature map F
            F_init, _ = self.net.generator([latent_W], input_is_latent=True, return_latents=False, start_layer=0, end_layer=3)
            optimizer_FS, latent_F, latent_S = self.setup_FS_optimizer(latent_W, F_init)


            pbar = tqdm(range(self.opts.FS_steps), desc='Embedding', leave=False)
            for step in pbar:

                optimizer_FS.zero_grad()
                latent_in = torch.stack(latent_S).unsqueeze(0)
                gen_im, _ = self.net.generator([latent_in], input_is_latent=True, return_latents=False,
                                               start_layer=4, end_layer=8, layer_in=latent_F)
                im_dict = {
                    'ref_im_H': ref_im_H.to(device),
                    'ref_im_L': ref_im_L.to(device),
                    'gen_im_H': gen_im,
                    'gen_im_L': self.downsample(gen_im)
                }

                loss, loss_dic = self.cal_loss(im_dict, latent_in)
                loss.backward()
                optimizer_FS.step()

                if self.opts.verbose:
                    pbar.set_description(
                        'Embedding: Loss: {:.3f}, L2 loss: {:.3f}, Perceptual loss: {:.3f}, P-norm loss: {:.3f}, L_F loss: {:.3f}'
                        .format(loss, loss_dic['l2'], loss_dic['percep'], loss_dic['p-norm'], loss_dic['l_F']))

            self.save_FS_results(ref_name, gen_im, latent_in, latent_F)
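
Note the two-stage generation: latent_F is the spatial feature map taken at the output of an early generator block (start_layer=0, end_layer=3, a 32×32×512 tensor whose resolution matches the 32×32 masks used in Alignment below), while latent_S is the W+ code driving the remaining blocks (start_layer=4, end_layer=8). Optimizing F directly makes the FS space much more expressive than W+, which is why the W+ inversion above serves only as its initialization.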

② Alignment

Alignment.py
align = Alignment(args)

class Alignment(nn.Module):

    def __init__(self, opts, net=None):
        super(Alignment, self).__init__()
        self.opts = opts
        if not net:
            self.net = Net(self.opts)
        else:
            self.net = net

        self.load_segmentation_network()
        self.load_downsampling()
        self.setup_align_loss_builder()

    def load_segmentation_network(self):
        self.seg = BiSeNet(n_classes=16)
        self.seg.to(self.opts.device)

        if not os.path.exists(self.opts.seg_ckpt):
            download_weight(self.opts.seg_ckpt)
        self.seg.load_state_dict(torch.load(self.opts.seg_ckpt))
        for param in self.seg.parameters():
            param.requires_grad = False
        self.seg.eval()

    def load_downsampling(self):

        self.downsample = BicubicDownSample(factor=self.opts.size // 512)
        self.downsample_256 = BicubicDownSample(factor=self.opts.size // 256)

    def setup_align_loss_builder(self):
        self.loss_builder = AlignLossBuilder(self.opts)

align.align_images(im_path1, im_path2, sign=args.sign, align_more_region=False, smooth=args.smooth)
: optimizes an aligned latent code that combines the identity of the first image with the hair structure of the second

    def align_images(self, img_path1, img_path2, sign='realistic', align_more_region=False, smooth=5,
                     save_intermediate=True):

        ################## img_path1: Identity Image
        ################## img_path2: Structure Image

        device = self.opts.device
        output_dir = self.opts.output_dir
        target_mask, hair_mask_target, hair_mask1, hair_mask2 = \
            self.create_target_segmentation_mask(img_path1=img_path1, img_path2=img_path2, sign=sign,
                                                 save_intermediate=save_intermediate)

        im_name_1 = os.path.splitext(os.path.basename(img_path1))[0]
        im_name_2 = os.path.splitext(os.path.basename(img_path2))[0]

        latent_FS_path_1 = os.path.join(output_dir, 'FS', f'{im_name_1}.npz')
        latent_FS_path_2 = os.path.join(output_dir, 'FS', f'{im_name_2}.npz')

        latent_1, latent_F_1 = load_FS_latent(latent_FS_path_1, device)
        latent_2, latent_F_2 = load_FS_latent(latent_FS_path_2, device)

        latent_W_path_1 = os.path.join(output_dir, 'W+', f'{im_name_1}.npy')
        latent_W_path_2 = os.path.join(output_dir, 'W+', f'{im_name_2}.npy')
  • save_intermediate : whether to save intermediate results
  • create_target_segmentation_mask : builds the target segmentation mask that designates which region comes from which image
  • load_FS_latent : loads an image's saved FS latent codes (latent_S, latent_F) from its .npz file

First alignment step

        optimizer_align, latent_align_1 = self.setup_align_optimizer(latent_W_path_1)

        pbar = tqdm(range(self.opts.align_steps1), desc='Align Step 1', leave=False)
        for step in pbar:
            optimizer_align.zero_grad()
            # only the coarse styles (first 6 rows of W+) are optimized; the finer styles stay fixed
            latent_in = torch.cat([latent_align_1[:, :6, :], latent_1[:, 6:, :]], dim=1)
            down_seg, _ = self.create_down_seg(latent_in)

            loss_dict = {}
            ##### Cross Entropy Loss
            ce_loss = self.loss_builder.cross_entropy_loss(down_seg, target_mask)
            loss_dict["ce_loss"] = ce_loss.item()
            loss = ce_loss

            loss.backward()
            optimizer_align.step()

        intermediate_align, _ = self.net.generator([latent_in], input_is_latent=True, return_latents=False,
                                                   start_layer=0, end_layer=3)
        intermediate_align = intermediate_align.clone().detach()
  • setup_align_optimizer : sets up the optimizer and the trainable latent code (initialized from the given W+ file) for alignment
  • create_down_seg : generates an image from the given latent code and segments it, returning the segmentation logits and the generated image (see the sketch below)
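
For context, create_down_seg roughly does the following (a reconstruction from how it is used here and in Blending below; seg_mean and seg_std are the repo's segmentation normalization constants):

    def create_down_seg(self, latent_in):
        gen_im, _ = self.net.generator([latent_in], input_is_latent=True, return_latents=False)
        gen_im_0_1 = (gen_im + 1) / 2                            # map [-1, 1] to [0, 1]
        im = (self.downsample(gen_im_0_1) - seg_mean) / seg_std  # 512x512 BiSeNet input
        down_seg, _, _ = self.seg(im)                            # per-pixel class logits
        return down_seg, gen_im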

Second alignment step

        optimizer_align, latent_align_2 = self.setup_align_optimizer(latent_W_path_2)

        with torch.no_grad():
            tmp_latent_in = torch.cat([latent_align_2[:, :6, :], latent_2[:, 6:, :]], dim=1)
            down_seg_tmp, I_Structure_Style_changed = self.create_down_seg(tmp_latent_in)

            current_mask_tmp = torch.argmax(down_seg_tmp, dim=1).long()
            # class 10 corresponds to hair in the 16-class segmentation used here
            HM_Structure = torch.where(current_mask_tmp == 10, torch.ones_like(current_mask_tmp),
                                       torch.zeros_like(current_mask_tmp))
            HM_Structure = F.interpolate(HM_Structure.float().unsqueeze(0), size=(256, 256), mode='nearest')

        pbar = tqdm(range(self.opts.align_steps2), desc='Align Step 2', leave=False)
        for step in pbar:
            optimizer_align.zero_grad()
            latent_in = torch.cat([latent_align_2[:, :6, :], latent_2[:, 6:, :]], dim=1)
            down_seg, gen_im = self.create_down_seg(latent_in)

            Current_Mask = torch.argmax(down_seg, dim=1).long()
            HM_G_512 = torch.where(Current_Mask == 10, torch.ones_like(Current_Mask),
                                   torch.zeros_like(Current_Mask)).float().unsqueeze(0)
            HM_G = F.interpolate(HM_G_512, size=(256, 256), mode='nearest')

            loss_dict = {}

            ########## Segmentation Loss
            ce_loss = self.loss_builder.cross_entropy_loss(down_seg, target_mask)
            loss_dict["ce_loss"] = ce_loss.item()
            loss = ce_loss

            #### Style Loss
            H1_region = self.downsample_256(I_Structure_Style_changed) * HM_Structure
            H2_region = self.downsample_256(gen_im) * HM_G
            style_loss = self.loss_builder.style_loss(H1_region, H2_region, mask1=HM_Structure, mask2=HM_G)

            loss_dict["style_loss"] = style_loss.item()
            loss += style_loss

            loss.backward()
            optimizer_align.step()

        latent_F_out_new, _ = self.net.generator([latent_in], input_is_latent=True, return_latents=False,
                                                 start_layer=0, end_layer=3)
        latent_F_out_new = latent_F_out_new.clone().detach()

Final alignment step

: builds masks for the relevant regions of the original and target images and mixes the two images' feature codes under those masks. This masked mixing is repeated over the stages below so the codes are gradually combined; the final image is generated from the combined code and saved.
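
Every mixing step below is the same masked interpolation: F_mixed = F_b + (1 - m) * (F_a - F_b) = m * F_b + (1 - m) * F_a, where m is the free-region mask downsampled to the 32×32 resolution of the feature code. Pixels inside the mask take the newly aligned features; everything else keeps the original ones.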

        free_mask = 1 - (1 - hair_mask1.unsqueeze(0)) * (1 - hair_mask_target)

        ##############################
        free_mask, _ = self.dilate_erosion(free_mask, device, dilate_erosion=smooth)
        ##############################

        free_mask_down_32 = F.interpolate(free_mask.float(), size=(32, 32), mode='bicubic')[0]
        interpolation_low = 1 - free_mask_down_32


        latent_F_mixed = intermediate_align + interpolation_low.unsqueeze(0) * (
                latent_F_1 - intermediate_align)

        if not align_more_region:
            free_mask = hair_mask_target
            ##########################
            _, free_mask = self.dilate_erosion(free_mask, device, dilate_erosion=smooth)
            ##########################
            free_mask_down_32 = F.interpolate(free_mask.float(), size=(32, 32), mode='bicubic')[0]
            interpolation_low = 1 - free_mask_down_32


        latent_F_mixed = latent_F_out_new + interpolation_low.unsqueeze(0) * (
                latent_F_mixed - latent_F_out_new)

        free_mask = F.interpolate((hair_mask2.unsqueeze(0) * hair_mask_target).float(), size=(256, 256), mode='nearest').cuda()
        ##########################
        _, free_mask = self.dilate_erosion(free_mask, device, dilate_erosion=smooth)
        ##########################
        free_mask_down_32 = F.interpolate(free_mask.float(), size=(32, 32), mode='bicubic')[0]
        interpolation_low = 1 - free_mask_down_32

        latent_F_mixed = latent_F_2 + interpolation_low.unsqueeze(0) * (
                latent_F_mixed - latent_F_2)

        gen_im, _ = self.net.generator([latent_1], input_is_latent=True, return_latents=False, start_layer=4,
                                       end_layer=8, layer_in=latent_F_mixed)
        self.save_align_results(im_name_1, im_name_2, sign, gen_im, latent_1, latent_F_mixed,
                                save_intermediate=save_intermediate)
  • dilate_erosion : dilates/erodes a mask to adjust its extent and smooth its boundary (a sketch follows below)
  • save_align_results : generates the final image from the blended latents and saves it
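
A hedged sketch of the dilation/erosion primitive; the repo's helpers are assumed to behave like iterated max/min filtering and return a (dilated, eroded) pair, which matches how they are unpacked above:

import torch.nn.functional as F

def dilate(mask, k=5):
    # grow the mask: max filter over a k x k neighborhood
    return F.max_pool2d(mask, kernel_size=k, stride=1, padding=k // 2)

def erode(mask, k=5):
    # shrink the mask: min filter, written as a negated max filter
    return -F.max_pool2d(-mask, kernel_size=k, stride=1, padding=k // 2)

Note how the first mix takes the dilated mask (free_mask, _ = ...) while the later ones take the eroded mask (_, free_mask = ...), leaving a soft band around the hair boundary where features can transition.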

③ Blending

Blending.py
blend = Blending(args)

class Blending(nn.Module):

    def __init__(self, opts, net=None):
        super(Blending, self).__init__()
        self.opts = opts
        if not net:
            self.net = Net(self.opts)
        else:
            self.net = net

        self.load_segmentation_network()
        self.load_downsampling()
        self.setup_blend_loss_builder()


    def load_segmentation_network(self):
        self.seg = BiSeNet(n_classes=16)
        self.seg.to(self.opts.device)

        if not os.path.exists(self.opts.seg_ckpt):
            download_weight(self.opts.seg_ckpt)
        self.seg.load_state_dict(torch.load(self.opts.seg_ckpt))
        for param in self.seg.parameters():
            param.requires_grad = False
        self.seg.eval()
        
    def load_downsampling(self):
        self.downsample = BicubicDownSample(factor=self.opts.size // 512)
        self.downsample_256 = BicubicDownSample(factor=self.opts.size // 256)

    def setup_blend_loss_builder(self):
        self.loss_builder = BlendLossBuilder(self.opts)

blend.blend_images(im_path1, im_path2, im_path3, sign=args.sign)
: a blending function that composites the features of the two aligned images and adds in the appearance of the third
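
The only trainable parameter in this step is interpolation_latent: it acts as an element-wise mixing weight α between the two style codes, latent_mixed = latent_1 + α * (latent_3 - latent_1). The masked losses below drive α so that the hair region follows image 3's appearance while the rest stays faithful to image 1.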

    def blend_images(self, img_path1, img_path2, img_path3, sign='realistic'):

        device = self.opts.device
        output_dir = self.opts.output_dir

        im_name_1 = os.path.splitext(os.path.basename(img_path1))[0]
        im_name_2 = os.path.splitext(os.path.basename(img_path2))[0]
        im_name_3 = os.path.splitext(os.path.basename(img_path3))[0]

        I_1 = load_image(img_path1, downsample=True).to(device).unsqueeze(0)
        I_3 = load_image(img_path3, downsample=True).to(device).unsqueeze(0)

        HM_1D, _ = cuda_unsqueeze(dilate_erosion_mask_path(img_path1, self.seg), device)
        HM_3D, HM_3E = cuda_unsqueeze(dilate_erosion_mask_path(img_path3, self.seg), device)

        opt_blend, interpolation_latent = self.setup_blend_optimizer()
        latent_1, latent_F_mixed = load_FS_latent(os.path.join(output_dir, 'Align_{}'.format(sign),
                                            '{}_{}.npz'.format(im_name_1, im_name_3)),device)
        latent_3, _ = load_FS_latent(os.path.join(output_dir, 'FS',
                                            '{}.npz'.format(im_name_3)), device)

        with torch.no_grad():
            I_X, _ = self.net.generator([latent_1], input_is_latent=True, return_latents=False, start_layer=4,
                               end_layer=8, layer_in=latent_F_mixed)
            I_X_0_1 = (I_X + 1) / 2
            IM = (self.downsample(I_X_0_1) - seg_mean) / seg_std
            down_seg, _, _ = self.seg(IM)
            current_mask = torch.argmax(down_seg, dim=1).long().cpu().float()
            HM_X = torch.where(current_mask == 10, torch.ones_like(current_mask), torch.zeros_like(current_mask))
            HM_X = F.interpolate(HM_X.unsqueeze(0), size=(256, 256), mode='nearest').squeeze()
            HM_XD, _ = cuda_unsqueeze(dilate_erosion_mask_tensor(HM_X), device)
            # region that is non-hair in image 1, image 3, and the generated image alike
            target_mask = (1 - HM_1D) * (1 - HM_3D) * (1 - HM_XD)


        pbar = tqdm(range(self.opts.blend_steps), desc='Blend', leave=False)
        for step in pbar:

            opt_blend.zero_grad()

            latent_mixed = latent_1 + interpolation_latent.unsqueeze(0) * (latent_3 - latent_1)

            I_G, _ = self.net.generator([latent_mixed], input_is_latent=True, return_latents=False, start_layer=4,
                               end_layer=8, layer_in=latent_F_mixed)
            I_G_0_1 = (I_G + 1) / 2

            im_dict = {
                'gen_im': self.downsample_256(I_G),
                'im_1': I_1,
                'im_3': I_3,
                'mask_face': target_mask,
                'mask_hair': HM_3E
            }
            loss, loss_dic = self.loss_builder(**im_dict)

            loss.backward()
            opt_blend.step()

        ############## Load F code from  '{}_{}.npz'.format(im_name_1, im_name_2)
        _, latent_F_mixed = load_FS_latent(os.path.join(output_dir, 'Align_{}'.format(sign),
                                                        '{}_{}.npz'.format(im_name_1, im_name_2)), device)
        I_G, _ = self.net.generator([latent_mixed], input_is_latent=True, return_latents=False, start_layer=4,
                           end_layer=8, layer_in=latent_F_mixed)

        self.save_blend_results(im_name_1, im_name_2, im_name_3, sign, I_G, latent_mixed, latent_F_mixed)