[AI] Custom Dataset & DataLoader

JAsmine_log·2024년 8월 24일
0

Pytorch

Custom Dataset

import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

init

__init__ 함수는 Dataset 객체를 인스턴스화할 때 한 번 실행되는 함수이다. 이 함수는 이미지가 포함된 디렉터리, 주석 파일(annotations file), 그리고 변환(transform)을 초기화한다.

  • labels.csv :
tshirt1.jpg, 0
tshirt2.jpg, 0
......
ankleboot999.jpg, 9
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
    self.img_labels = pd.read_csv(annotations_file)
    self.img_dir = img_dir
    self.transform = transform
    self.target_transform = target_transform

len

__len__은 Dataset의 길이를 나나내는 함수이다.

Example:

def __len__(self):
    return len(self.img_labels)

getitem

__getitem__은 주어진 인덱스(idx)에 해당하는 데이터셋의 샘플을 로드하고 반환하는 함수이다. 이 함수는 인덱스를 기반으로 이미지의 위치를 디스크에서 식별하고, read_image를 사용하여 이를 텐서로 변환하며, self.img_labels에 있는 CSV 데이터에서 해당하는 레이블을 검색한다. 이후, 변환 함수가 적용 가능한 경우에는 이를 호출하고, 텐서 이미지와 해당 레이블을 튜플 형태로 반환한다.

def __getitem__(self, idx):
    img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
    image = read_image(img_path)
    label = self.img_labels.iloc[idx, 1]
    if self.transform:
        image = self.transform(image)
    if self.target_transform:
        label = self.target_transform(label)
    return image, label

Preparing your data for training with DataLoaders

Dataset은 데이터셋의 특징과 레이블을 한 번에 하나씩 가져온다. 모델을 학습할 때는 일반적으로 샘플을 “미니배치(minibatches)”로 전달하고, 모델의 과적합을 줄이기 위해 매 에포크(epoch)마다 데이터를 다시 섞으며, 데이터 검색 속도를 높이기 위해 Python의 멀티프로세싱을 사용한다.
DataLoader는 이러한 복잡성을 간단한 API로 추상화하여 제공하는 반복 가능한 객체이다.

from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

Iterate through the DataLoader

데이터셋은 DataLoader에 로드하고, 필요에 따라 데이터셋을 반복(iterate)할 수 있다. 아래의 각 반복(iteration)은 배치 크기(batch_size)=64의 train_features와 train_labels를 반환한다. shuffle=True로 지정했기 때문에, 모든 배치를 반복한 후에는 데이터가 섞인다. 데이터 로딩 순서에 대한 더 세밀한 제어가 필요하다면 Samplers를 참조하여 사용할 수 있다.

# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

Refernece
[1] pytorch, Datasets & DataLoaders, https://pytorch.org/tutorials/beginner/basics/data_tutorial.html#creating-a-custom-dataset-for-your-files

profile
Everyday Research & Development

0개의 댓글