PyTorch 기초 - Dataset 만들기

sp·2022년 2월 19일
1

PyTorch 기초

목록 보기
2/7
post-thumbnail

딥러닝 모델을 학습하기 전에 필요한 첫 번째 준비물은 데이터라고 할 수 있습니다. 주어진 데이터를 활용해 효과적으로 모델에 입력하기 위해서 PyTorch는 Dataset 클래스를 제공하고 있습니다. 이 포스트에서는 이를 활용해서 데이터셋을 만드는 방법에 대해 알아보겠습니다.

간단한 Dataset 만들기

먼저 이 포스트에서 사용할 모듈들을 불러오겠습니다.

import os
import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms

커스텀 데이터셋을 만들기 위해서 다음과 같이 클래스를 구현해보겠습니다.

class MyBaseDataset(Dataset):
    def __init__(self, x_data, y_data):
        self.x_data = x_data
        self.y_data = y_data
        
    def __getitem__(self, index): 
        return self.x_data[index], self.y_data[index]
        
    def __len__(self): 
        return self.x_data.shape[0]

이번에 구현한 MyBaseDatasetDataset 클래스를 상속하게 됩니다. 이 클래스를 상속했을 때, 구현해야 하는 메서드는 __init__, __getitem__, __len__입니다.

  • __init__ 메서드는 객체를 생성할 때 실행되는 메서드, 즉 생성자입니다. 여기에는 모델에 사용할 데이터를 담아두는 등 어떤 인덱스가 주어졌을 때 반환할 수 있게 만드는 초기 작업을 수행합니다.

  • __getitem__ 메서드는 어떤 인덱스가 주어졌을 때 해당되는 데이터를 반환하는 메서드입니다. numpy 배열이나 텐서 형식으로 반환합니다. 보통 입력과 출력을 튜플 형식으로 반환하게 됩니다.

  • __len__은 학습에 사용할 데이터의 총 개수라고 볼 수 있는데, 즉 얼마만큼의 인덱스를 사용할지를 반환하는 메서드입니다.

위 코드로 구현한 BaseDataset은 정형적인 데이터셋을 나타내고 있습니다. 가장 높은 차원에 인덱스로 접근해 데이터를 반환하게 됩니다. 이 클래스에 간단한 데이터를 넣은 예시를 보이겠습니다.

x_data = torch.arange(100)
y_data = x_data * x_data
dataset = MyBaseDataset(x_data, y_data)
print("dataset example: ", dataset[0], dataset[1], dataset[2])
print("dataset length:", len(dataset))
dataset example:  (tensor(0), tensor(0)) (tensor(1), tensor(1)) (tensor(2), tensor(4))
dataset length: 100

이 코드는 y=x2y=x^2에 대해 xx가 0에서 99까지 주어진 데이터셋을 만든 것으로 볼 수 있습니다. 만들어진 데이터셋은 구현한 메서드를 기반으로 리스트처럼 인덱스로 접근할 수 있고, 길이를 알 수 있습니다.

인덱스에 접근할 때 데이터 불러오기

위 예시와 같이 모든 데이터를 Dataset에 저장할 수 있다면 구현하기도 편할 것입니다. 그러나 ImageNet과 같이 1400만개가 넘는 데이터셋을 메모리에 모두 불러오는 것은 사실상 불가능할 것입니다. 그래서 일반적으로는 생성자에서 모든 데이터를 불러오지는 않고, 불러올 이미지들의 경로를 저장하는 방식을 사용합니다. 그리고 데이터셋에 인덱스로 접근할 때 경로에 존재하는 데이터를 불러와 메모리에 적재하면 이 문제를 해결할 수 있습니다. 물론 보조기억장치에서 데이터를 불러올 때의 지연이 나타날 수 있지만 보통은 무시할 수 있습니다. 참고로 데이터들을 한 파일로 만들어서 활용하고 싶다면, HDF5 파일을 활용해 볼 수 있습니다. 전처리로 HDF5 파일을 만든 후, 학습에 이 파일을 불러와 사용하는 방법을 사용할 수 있습니다.

그러면 인덱스에 접근할 때 파일 경로로부터 데이터를 불러오는 Dataset 클래스를 구현해보겠습니다.

class DogCatDataset(Dataset):
    def __init__(self, data_dir):
        self.data_dir = data_dir
        self.image_path_list = os.listdir(data_dir)
        self.transform = transforms.ToTensor()

    def __getitem__(self, index):
        image_path = os.path.join(self.data_dir, self.image_path_list[index])

        x_data = Image.open(image_path)
        x_data = self.transform(x_data)
        y_data = 1 if "dog" in self.image_path_list[index] else 0

        return x_data, y_data

    def __len__(self):
        return len(self.image_path_list)

이 코드에서는 개(Dog)와 고양이(Cat)을 구별하기 위한 데이터셋으로 볼 수 있습니다. 생성자에서는 인수로 넘어온 data_dir의 파일들을 데이터로 사용하기 위해 리스트로 저장하고, 이미지를 텐서로 변환하는 transform을 사용합니다. (관련 내용은 잠시 뒤에 다루겠습니다.) __getitem__로 인덱스가 넘어오면, 경로에 해당하는 이미지를 불러와 텐서로 변환해 x_data로 반환합니다. 타겟 데이터(y_data)는 파일 이름으로부터 추출하게 됩니다. 이 예시에서는 파일들이 dog.10.jpg, cat.3.jpg 등으로 구분되어 있다고 가정해 볼 때, 개일 경우에는 dog 문자열이 포함되어 있고 고양이일 경우에는 이 문자열이 포함되어 있지 않다고 바꾸어 볼 수 있습니다. 이를 if else문으로 판단해 반환하게 됩니다.

불러온 이미지를 변형하기

여기서 불러온 이미지들을 바로 텐서로 변환해서 반환할 수도 있지만, 이미지의 크기가 일정하지 않거나, 데이터 증강(Data augmentation)을 하고 싶을 때가 있습니다. 이 때에는 torchvision에서 제공하는 transforms 모듈들을 사용하면 됩니다.

간단한 사용방법을 위 코드의 일부에서 가져와서 살펴보겠습니다.

transform = transforms.ToTensor()
x_data = transform(x_data)

가져올 변형(여기서는 ToTensor)을 가져와서 변수로 선언한 다음에, 변형할 데이터를 매개변수로 넣는 방식으로 진행합니다. 그래서 생성자에서 transform을 인스턴스 변수로 선언하고 __getitem__에서 이를 활용하는 방식으로 사용하는 것입니다.

이 변형 중에서 주요한 몇몇 모듈들을 살펴보겠습니다.

torchvision.transforms.Resize(size, interpolation='bilinear', ...)
torchvision.transforms.CenterCrop(size)
torchvision.transforms.RandomHorizontalFlip(p=0.5)
torchvision.transforms.RandomVerticalFlip(p=0.5)
torchvision.transforms.RandomRotation(degrees, interpolation='nearest', ...)
torchvision.transforms.RandomCrop(size, padding=None, ...)

위 변형들은 이미지 전처리 또는 데이터 증강에 주로 사용됩니다. 크기 조절, 뒤집기, 회전하기 등의 작업을 수행할 수 있습니다.

torchvision.transforms.ToTensor()
torchvision.transforms.ToPILImage(mode=None)

텐서와 PIL 이미지로 상호 변환하는 변형들입니다. ToTensor은 불러온 이미지를 텐서로 불러올 때, ToPILImage는 모델의 결과로 나온 영상을 저장하기 위해 주로 사용됩니다. 여기서 ToTensor로 텐서로 변환한 이미지의 값은 [0, 1]의 범위를 가집니다.

torchvision.transforms.Normalize(mean, std, inplace=False)

Normalize는 텐서를 정규화하는 모듈입니다. mean과 std를 매개변수로 입력하게 되는데, 들어온 텐서 xx에 대해 (xmean)/std(x-mean) / std를 수행한 텐서를 리턴받습니다. 보통 이 모듈을 사용할 때에는 [0, 1][-1, 1] 사이의 변환과 PyTorch에서 제공하는 이미 학습된(pretrained) 모델을 파인튜닝(fine-tuning)하기 위해 RGB 데이터 분포에 맞춰주는 것입니다. RGB 영상에 대해 각각에 해당하는 예시를 살펴보게습니다.

transform_1 = transforms.Normalize([0.5, 0.5, 0.5], [[0.5, 0.5, 0.5])
transform_2 = transforms.Normalize([1, 1, 1], [[2, 2, 2])
transform_3 = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

transform_1[0, 1]에서 [-1, 1]로, transform_2은 그 반대, transform_3은 파인튜닝하기 위해 사용되는 값들이라고 볼 수 있습니다.

마지막으로, 2개 이상의 변형들을 한번에 적용하기 위한 모듈을 살펴보겠습니다.

torchvision.transforms.Compose(transforms)

여기서 매개변수로 들어가는 transforms에 사용할 변형들을 리스트로 넣어주면 순서대로 수행해 반환해줍니다. 일반적으로 사용할 수 있는 예시는 다음과 같습니다.

transform_compose = transforms.Compose([
    transforms.RandomResizedCrop(input_size),
    transforms.RandomRotation(degrees=(-30, 30))
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

그 외 모든 transforms 모듈들에 대한 정보를 찾을 때에는 공식 링크를 참조하시면 됩니다.

0개의 댓글