오늘은 저번 Computer Vision - Image Segmentation(1)에서 소개한 공식 사이트 Github에 올라온 예제 코드를 작서해볼 생각입니다. 코드 주소는 여기를 클릭해 주세요.
!pip install segmentation-models-pytorch
!pip install pytorch-lightning==1.5.4
import os
import torch
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import segmentation_models_pytorch as smp # 저희가 사용할 모델입니다.
from pprint import pprint
from torch.utils.data import DataLoader
segmentation이라 image와 mask가 필요합니다. 저장된 방식은 여러분에게 주어진 상황에 따라 달라지니 이번 코드는 참조만 해주시길 바랍니다.
예제로 나온 코드는 segmentation model에서 제공하는 SimpleOxfordPetDataset을 가져왔습니다.
from segmentation_models_pytorch.datasets import SimpleOxfordPetDataset
# 데이터를 지정한 위치에 저장합니다. root + '/datafold'로 저장이 되니 root를 변경하세요
root = "."
SimpleOxfordPetDataset.download(root)
# 훈련 세트, 검증 세트, 테스트 세트로 나눠줍니다.
train_dataset = SimpleOxfordPetDataset(root, "train")
valid_dataset = SimpleOxfordPetDataset(root, "valid")
test_dataset = SimpleOxfordPetDataset(root, "test")
# 서로 겹치는 데이터가 없는지 검증합니다. 이런 방식도 있구나 하고 잘 활용해 보세요. 필자도 배워갑니다.
assert set(test_dataset.filenames).isdisjoint(set(train_dataset.filenames))
assert set(test_dataset.filenames).isdisjoint(set(valid_dataset.filenames))
assert set(train_dataset.filenames).isdisjoint(set(valid_dataset.filenames))
print(f"Train size: {len(train_dataset)}")
print(f"Valid size: {len(valid_dataset)}")
print(f"Test size: {len(test_dataset)}")
# 할당된 cpu만큼 사용해 각자 dataloader을 만듭니다. 잘 모르겠고, 혹여나 오류가 생긴다면 num_workers=0 으로 변경하면 자동으로 정합니다.
n_cpu = os.cpu_count()
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=n_cpu)
valid_dataloader = DataLoader(valid_dataset, batch_size=16, shuffle=False, num_workers=n_cpu)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=n_cpu)
훈련, 검증, 테스트 세트와 로더를 만들었다면 실제로 잘 나뉘어 졌는지 화면상으로 봅시다.
이 부분은 모델 학습에 필요한 부분이 아니라 뛰어 넘어가도 됩니다.
sample = train_dataset[0]
plt.subplot(1,2,1)
plt.imshow(sample["image"].transpose(1, 2, 0)) # for visualization we have to transpose back to HWC
plt.subplot(1,2,2)
plt.imshow(sample["mask"].squeeze()) # for visualization we have to remove 3rd dimension of mask
plt.show()
sample = valid_dataset[0]
plt.subplot(1,2,1)
plt.imshow(sample["image"].transpose(1, 2, 0)) # for visualization we have to transpose back to HWC
plt.subplot(1,2,2)
plt.imshow(sample["mask"].squeeze()) # for visualization we have to remove 3rd dimension of mask
plt.show()
sample = test_dataset[0]
plt.subplot(1,2,1)
plt.imshow(sample["image"].transpose(1, 2, 0)) # for visualization we have to transpose back to HWC
plt.subplot(1,2,2)
plt.imshow(sample["mask"].squeeze()) # for visualization we have to remove 3rd dimension of mask
plt.show()