Reference : https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
코드로 데이터 셈플들을 다루는 작업은 매우 복잡하고 유지하기가 어렵다. 그래서 파이토치는 데이터셋 코드를 트레이닝 코드와 구별하여 더 이해하기 쉽고 직관적으로 만들었다.
PyTorch provides two data primitives:
torch.utils.data.DataLoader
torch.utils.data.Dataset
Dataset
은 샘플들과 그에 맞는 정답(labels)들을 저장한다.
DataLoader
는 이러한 Dataset
을 iterable
하게 감싸는데, 쉽게 샘플들에 접근할 수 있도록 하기 위함이다.
간단하게 Fashion-MNIST 예제를 들어 설명하자면
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor()
)
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor()
)
Iterating and Visualizing the Datset
우리는 Datasets
를 리스트 같이 만들 수 있다:training_data[index]
.
labels_map = {
0: "T-Shirt",
1: "Trouser",
2: "Pullover",
3: "Dress",
4: "Coat",
5: "Sandal",
6: "Shirt",
7: "Sneaker",
8: "Bag",
9: "Ankle Boot",
}
figure = plt.figure(figsize=(8,8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
sample_idx = torch.randint(len(training_data), size=(1,)).item()
img, label = training_data[sample_idx]
figure.add_subplot(rows, cols, i)
plt.title(labels_map[label])
plt.axis("off")
plt.imshow(img.squeeze(), cmap="gray")
plt.show()
torch.randint(low(int, optional), high(int), size(tuple))
Returns a tensor filled with random integers generated uniformly between low (inclusive) and high (exclusive).
>>> torch.randint(3, 5, (3,))
tensor([4, 3, 4])
Creating a Custom Dataset for your files
custom Dataset
은 무조건 세가지 function을 가지고 있어야 한다
__init__
__len__
__getitem__
import os
import pandas as pd
from torchvision.io import read_image
class CustomImageDataset(Dataset):
def __init__(self, annotations_file, img_dir, transform=None,
target_transform=None):
self.img_labels = pd.read_csv(annotations_file)
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
def __len__(self):
return len(self.img_labeles)
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
__init__
__init__는 Dataset object
를 인스턴스화 할 때 한번만 딱 실행된다. 여기서는 image를 담고있는 directory와 annotation파일, 그리고 두가지의 변환을 선언한다.
The labels.csv file looks like:
tshirt1.jpg, 0
tshirt2.jpg, 0
......
ankleboot999.jpg, 9
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_labels = pd.read_csv(annotations_file)
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
__len__
__len__은 샘플들의 숫자를 알려준다.
def __len__(self):
return len(self.img_labels)