PyTorch 실전 - CycleGAN 테스트 과정 구현하기

sp·2022년 4월 6일
0

PyTorch 실전

목록 보기
4/4
post-thumbnail

이전 포스트에 구현한 대로 학습 데이터셋으로 CycleGAN 모델을 학습했으면 그 다음으로 테스트 과정을 구현하는 것이 마지막 과정이라고 할 수 있겠습니다. 이 포스트에서는 학습된 가중치를 가져와서 테스트 이미지에 적용해 보고, 모델을 통과한 이미지가 잘 변환이 되었는지 살펴보도록 하겠습니다.

학습 단계와는 다르게 모델을 통과하는 것이 전부이므로, 상대적으로 코드가 간단한 편입니다. 이를 순서대로 살펴보겠습니다.

추론 코드 구현하기

import argparse
import os
import glob
import torch
from PIL import Image

import torchvision.transforms as transforms

from models import Generator


parser = argparse.ArgumentParser()
parser.add_argument("--dataset_path", type=str, default="datasets/apple2orange")
parser.add_argument("--checkpoint_path", type=str)
parser.add_argument("--size", type=int, default=256)

args = parser.parse_args()


def test():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    # 1, Model
    num_blocks = 6 if args.size <= 256 else 8
    netG_A2B = Generator(num_blocks).to(device)
    netG_B2A = Generator(num_blocks).to(device)

    # 2
    checkpoint = torch.load(args.checkpoint_path, map_location=device)
    netG_A2B.load_state_dict(checkpoint["netG_A2B_state_dict"])
    netG_B2A.load_state_dict(checkpoint["netG_B2A_state_dict"])

    netG_A2B.eval()
    netG_B2A.eval()

    # 3, Dataset
    transform_to_tensor = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ]
    )
    # 4
    transform_to_image = transforms.Compose(
        [
            transforms.Normalize(mean=(-1, -1, -1), std=(2, 2, 2)),
            transforms.ToPILImage(),
        ]
    )

    # 5
    dataset_name = os.path.basename(args.dataset_path)

    os.makedirs(f"results/{dataset_name}/testA", exist_ok=True)
    os.makedirs(f"results/{dataset_name}/testB", exist_ok=True)

    test_list = [["testA", netG_A2B], ["testB", netG_B2A]]

    # 6
    for folder_name, model in test_list:
        print(folder_name)
        image_path_list = sorted(
            glob.glob(os.path.join(args.dataset_path, folder_name) + "/*")
        )
        for idx, image_path in enumerate(image_path_list):
            image_name = os.path.basename(image_path)
            print(f"{idx}/{len(image_path_list)} {image_name}")

            # 7
            image = Image.open(image_path)
            image = transform_to_tensor(image).unsqueeze(0)

            output = model(image)

            # 8
            output = transform_to_image(output.squeeze())
            output.save(os.path.join("results", dataset_name, folder_name, image_name))


if __name__ == "__main__":
    test()
  1. 학습 과정에서 checkpoint에 가중치만 저장했기 때문에 생성자를 선언하는 과정이 필요합니다. 여기서 가중치의 수를 맞추어 주기 위해 block 수를 포함해 초기화하도록 합니다.

  2. checkpoint로부터 가중치를 불러들입니다. 여기서 추론은 CPU로도 진행할 수 있음을 가정해서, map_location으로 장치 간 불러오기를 지원하도록 합니다.

  3. 생성자에 넣기 위해 이미지로부터 텐서로 바꾸고, 정규화를 수행합니다. 여기서는 0에서 1 사이의 값이 -1에서 1로 바뀝니다.

  4. 모델에 통과하고, 이미지로 변환하기 위한 변형을 수행합니다. 여기서 정규화는 -1에서 1 사이의 값이 0에서 1로 변환됩니다.

  5. results 폴더 내에 데이터셋 이름 내 testA, testB 폴더 안에 각 이미지가 들어가도록 합니다. 이는 정규 데이터셋의 구조를 그대로 따르도록 구현합니다. 그리고 testA는 style B로 변환하기 위해 netG_A2B를 사용하고, testB는 netG_B2A를 사용하도록 합니다.

  6. testA, testB 폴더에 대해 테스트를 수행하도록 합니다. 각 폴더에 대해 이미지들의 경로들을 image_path_list로 만들고, 각 경로에 해당하는 이미지에 대해 추론을 진행하도록 합니다.

  7. 이미지를 불러와서 모델의 입력에 들어가기 위한 전처리를 수행합니다. unsqueeze는 일반적인 연산들이 배치 단위로 들어가는 것을 고려하기 위해 넣습니다. 예를 들어 [3, 256, 256]의 크기의 텐서에서 배치 차원을 추가해 [1, 3, 256, 256]의 형태로 만드는 것으로 볼 수 있습니다.

  8. 모델에 통과된 이미지에서 제일 앞단 배치를 빼고, 후처리를 통해 이미지로 만듭니다. 그리고 이를 해당하는 폴더에 같은 이름으로 저장하도록 합니다.

추론하고 결과 확인하기

코드를 작성했으면 다음과 같이 실행하면 됩니다.

python3 test.py --dataset_path apple2orange  --checkpoint_path checkpoint/apple2orange/500.pth

dataset_path로는 학습때와 동일하게 작성하면 되고, checkpoint_path는 학습 때 저장된 경로로 입력하면 됩니다. 500 미만 에포크만큼 학습된 checkpoint도 명시해서 비교할 수도 있습니다.

입력 이미지와 모델을 통과된 이미지 쌍들 중 일부를 나타내면 다음과 같습니다.

두 스타일 간 이미지가 그럭저럭 잘 변환된 것을 확인할 수 있습니다.

0개의 댓글