pytorch를 사용하여 DL 모델을 학습할 때, data의 기본 원소가 되는 두 클래스:
Dataset과 DataLoader class를 알아보도록 하겠습니다!
https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
공식 홈페이지에도 아주 잘 나와있습니다 :)
Dataset stores the samples and their corresponding labels
✔ Dataset class는 모델 학습시 사용할 data와 label을 저장합니다.
All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite __getitem__(), supporting fetching a data sample for a given key. Subclasses could also optionally overwrite __len__(), which is expected to return the size of the dataset by many Sampler implementations and the default options of DataLoader.
✔ 미리 작성된 torch.utils.data.Dataset class를 상속(서브클래싱)하여 사용합니다.
✔ __init__() 메서드로 인스턴스 생성시 필요한 초기값들을 받습니다. (이미지 리스트, transform 함수 등)
✔ __getitem__() 메서드의 리턴값이 추후 학습에서 사용할 값이 되도록 오버라이딩합니다.
import torch
from PIL import Image
class MyDataset(torch.utils.data.Dataset):
"""
Attributes
----------
img_list : 리스트
이미지의 경로를 저장한 리스트
label_list : 리스트
label의 경로를 저장한 리스트
phase : 'train' or 'val'
학습 또는 테스트 여부 결정
transform : object
전처리 클래스의 인스턴스
"""
def __init__(self, img_list, label_list, phase, transform):
self.img_list = img_list
self.label_list = label_list
self.phase = phase # train 또는 val을 지정
self.transform = transform # 이미지의 변형
def __len__(self):
'''이미지의 갯수를 반환'''
return len(self.img_list)
def __getitem__(self, index):
'''
전처리한 이미지 및 라벨 return
'''
image_path = self.file_list[index]
img = Image.open(image_path)
transformed_img = self.transform(img, self.phase)
label = self.label_list[index]
return transformed_img, label
😁 어때요? 정말 간단하죠!
그리고, transform 인자에 들어갈 인스턴스의 틀이 되는 클래스는 보통 __call__() 메서드를 정의하는 방법으로 많이 사용됩니다.
✔ __call__() : 클래스 인스턴스를 생성 후, () 명령어를 사용하면 실행되는 함수입니다.
from torchvision import models, transforms
class MyTransform():
"""
Attributes
----------
resize : int
Transform 수행 후 변경될 width / height 값.
mean : (R, G, B)
각 색상 채널의 평균값.
std : (R, G, B)
각 색상 채널의 표준 편차.
"""
def __init__(self, resize, mean, std):
self.data_transform = {
'train': transforms.Compose([
transforms.RandomResizedCrop(
resize, scale=(0.5, 1.0)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(), # 텐서로 변환
transforms.Normalize(mean, std) # 표준화
]),
'val': transforms.Compose([
transforms.Resize(resize),
transforms.ToTensor(), # 텐서로 변환
transforms.Normalize(mean, std) # 표준화
])
}
def __call__(self, img, phase='train'):
"""
Parameters
----------
phase : 'train' or 'val'
전처리 모드를 지정.
"""
return self.data_transform[phase](img)
예시에서는 학습을 위해 RandomResizedCrop() / RandomHorizontalFlip() 만 적용을 하였지만, 더 다양한 Augmentation을 적용할 수도 있습니다. 😁
✔ 실제로 학습을 위한 인스턴스는 아래와 같이 생성됩니다 😎
size = 224
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
train_dataset = MyDataset(file_list=train_img_list, phase="train", transform=MyTransform(
size, mean, std)))
val_dataset = MyDataset(file_list=val_img_list, phase="val", transform=MyTransform(
size, mean, std)))
✔ train_dataset / val_dataset은 generator로써, 추후 '실제로 반복이 수행될 때' 메모리를 할당하여 작업을 수행합니다.
✔ D/L 모델 학습을 수행할 때, generator를 사용하지 않으면 OOM이 종종 발생하고는 하여, torch에서는 더욱 간단히 쓸 수 있도록 틀을 제공한 것으로 이해되네요 😁
DataLoader wraps an iterable around the Dataset to enable easy access to the samples
✔ DataLoader class는 Dataset class로 정의된 데이터 뭉치를 쉽게 샘플 단위로 가져올 수 있게 합니다.
✔ 즉, 보통 D/L 모델 학습을 위한 mini-batch 학습 단위로 가져오게 해주는 역할을 수행합니다.
🤦♀️ 그리고 코드는 넘나 간단.. :)
# DataLoader를 만든다
batch_size = 32
train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(
val_dataset, batch_size=batch_size, shuffle=False)
# 사전 객체에 정리
dataloaders_dict = {"train": train_dataloader, "val": val_dataloader}
✔ batch size 단위로 가져오게 되어, train_dataloader의 iteration을 진행하면, [32, 224, 224, 3]의 shape을 가진 Tensor를 얻을 수 있습니다.
✨ 그리고 학습 수행하면 오케이!
👍 추가) Object Detection Task 등, batch에 포함된 데이터의 크기가 각각 다를 때가 있습니다. (각 이미지에 라벨이 몇개씩 있는지 알 수 없으므로) 이 때, DataLoader 인스턴스를 생성할 시 collate_fn 인자를 설정하여 해결할 수 있습니다.
https://pytorch.org/docs/stable/data.html?highlight=collate_fn
다음 글은 간단히 Fine Tuning을 진행하는 글을 작성하도록 하겠습니다 :>