[PyTorch]Dataset & DataLoader

MA·2022년 7월 25일
0

PyTorch

목록 보기
6/6

Reference : https://pytorch.org/tutorials/beginner/basics/data_tutorial.html

코드로 데이터 셈플들을 다루는 작업은 매우 복잡하고 유지하기가 어렵다. 그래서 파이토치는 데이터셋 코드를 트레이닝 코드와 구별하여 더 이해하기 쉽고 직관적으로 만들었다.

PyTorch provides two data primitives:
torch.utils.data.DataLoader
torch.utils.data.Dataset

Dataset은 샘플들과 그에 맞는 정답(labels)들을 저장한다.
DataLoader는 이러한 Datasetiterable하게 감싸는데, 쉽게 샘플들에 접근할 수 있도록 하기 위함이다.

간단하게 Fashion-MNIST 예제를 들어 설명하자면

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

training_data = datasets.FashionMNIST(
	root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
	root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

Iterating and Visualizing the Datset

우리는 Datasets를 리스트 같이 만들 수 있다:training_data[index].

labels_map = {
	0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8,8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
	sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

torch.randint(low(int, optional), high(int), size(tuple))
Returns a tensor filled with random integers generated uniformly between low (inclusive) and high (exclusive).

>>> torch.randint(3, 5, (3,))
tensor([4, 3, 4])

Creating a Custom Dataset for your files

custom Dataset은 무조건 세가지 function을 가지고 있어야 한다
\bullet __init__
\bullet __len__
\bullet __getitem__

import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
	def __init__(self, annotations_file, img_dir, transform=None,
    target_transform=None):
    	self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform
        
    def __len__(self):
    	return len(self.img_labeles)
        
    def __getitem__(self, idx):
    	img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
        	image = self.transform(image)
        if self.target_transform:
        	label = self.target_transform(label)
        return image, label

__init__

__init__는 Dataset object를 인스턴스화 할 때 한번만 딱 실행된다. 여기서는 image를 담고있는 directory와 annotation파일, 그리고 두가지의 변환을 선언한다.

The labels.csv file looks like:

tshirt1.jpg, 0
tshirt2.jpg, 0
......
ankleboot999.jpg, 9
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
    self.img_labels = pd.read_csv(annotations_file)
    self.img_dir = img_dir
    self.transform = transform
    self.target_transform = target_transform

__len__

__len__은 샘플들의 숫자를 알려준다.

def __len__(self):
    return len(self.img_labels)

profile
급할수록 돌아가라

0개의 댓글