torch.utils.data.Dataset
으로 불러올 수 있는 추상 클래스입니다.
추상 클래스이므로,
1) 반드시 사전에 모듈(torch.utils.data
) import가 필요하며,
2) 자식 클래스에서 구현해야 하는 추상 메소드가 존재합니다.
Dataset에서 구현해야 하는 메소드는 __init__
, __getitem__
, __len__
이 있습니다. 커스텀 데이터셋 구축 시에는 해당 메소드들을 만들어줘야 합니다.
기본적으로 데이터를 전처리하고 다루기 편한 형태로 묶어주는 일을 한다고 생각하면 됩니다. 예를 들어, 학습을 위한 raw image가 5000장 있고 각 이미지의 라벨 정보가 json파일에 저장돼 있다고 가정해 보겠습니다.
이 경우, 미니배치 또는 배치 학습을 위해 매번 json 파일을 읽는 것은 비효율적입니다. 따라서 커스텀 데이터셋을 구축하는 과정에서 딕셔너리 등의 형태로 이미지 파일과 라벨을 매칭해 두면 이후에도 쉽게 가져다 쓸 수 있습니다.
Dataset에서 전처리를 맡는 부분은 주로 __init__
이고, __len__
과 __getitem__
은 반복 사용을 위한 인덱싱 작업에 사용됩니다.
import os
import pandas as pd
from torchvision.io import read_image
class CustomImageDataset(Dataset):
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_labels = pd.read_csv(annotations_file, names=['file_name', 'label'])
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
def __len__(self):
return len(self.img_labels)
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
위 코드는 공식 튜토리얼의 커스텀 데이터셋 예시 코드입니다.
어떤 return 값을 줄 것인지 필요에 따라 구현하면 됩니다.
torch.utils.data.DataLoader
로 불러올 수 있는 클래스입니다.
주로 Dataset으로 구축한 데이터를 batch 단위로 나누는 역할입니다.
주어진 dataset과 sampler를 감싸서 iterable한 객체로 반환하는 일을 합니다. map-style
, iterable-style
두 가지 데이터셋을 모두 받을 수 있습니다. 자세한 용법은 공식 문서를 참고하면 됩니다.
from torch.utils.data import DataLoader
train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
train_feature, train_labels = next(iter(train_dataloader))
print(train_feature.size(), train_labels.size())
>>> torch.Size([64, 1, 28, 28]) torch.Size([64])
예시 코드를 보겠습니다.
batch_size = 64이기 때문에 전체 데이터를 64개씩 자릅니다.
Dataset 함수에서 image, label 두 개를 리턴하므로, train_feature.size()
와 train_labels.size()
로 받은 텐서의 크기는 각각 (64, 1, 28, 28)과 (64,)가 됩니다.
참고문헌
1) 파이토치 한국어 튜토리얼 (https://tutorials.pytorch.kr/beginner/basics/data_tutorial.html)
2) 파이토치 공식 문서 (https://pytorch.org/docs/stable/data.html?highlight=dataset#torch.utils.data.Dataset)