ex)
import torch
from torch.utils.data import Dataset
class CustomDataset(Dataset):
# 초기 데이터 생성 방법 지정
def __init__(self, data, labels):
self.data = data
self.data = text
# 데이터 길이
def __len__(self):
return len(self.labels)
# index 값이 주어졌을 때 반환되는 데이터
def __getitem__(self, index):
data = self.data[index]
label = self.labels[index]
return data, label
ex)
text = ['Mad', 'Joy', 'Glum', 'Happy']
labels = ['Negative', 'Positive', 'Negative', 'Positive']
MyDataset = CustomDataset(text, labels)
MyDataLoader = DataLoader(MyDataset, batch_size=3, shuffle=True) # iterable generator
# next(iter(MyDataLoader)) # 다음 이터 찍어보기
for dataset in MyDataLoader:
print(dataset)
# {'Text': ['Mad', 'Happy', 'Glum'], 'Class': ['Negative', 'Positive', 'Negative']}
# {'Text': ['Joy'], 'Class' : ['Positive']}
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
batch_size
: 배치의 크기
예를 들어, 데이터가 100개 있고 batch가 20이면 5번의 iteration이 지나고 모든 데이터를 본다.
shuffle
: 데이터를 섞느냐
sampler
: 데이터를 어떻게 뽑을지 index를 조정하는 법
num_workers
: 데이터 로딩에 사용하는 subprocess 개수.
많다고 좋은 게 X. -> CPU와 GPU 사이에 너무 많은 교류는 병목을 발생시킴.
collate_fn
: variable length를 처리하기 위해 많이 씀. 길이가 다른 데이터를 padding 할 때 여기서 정의
[[Data, Label], [Data, Label]] --> [[Data, Data], [Label, Label]]
drop_last
: 데이터가 batch로 나누어 떨어지지 않고 나머지가 생길 때 마지막 배치의 길이가 달라 loss를 구하기 번거롭거나, batch size에 의존도가 높은 함수를 사용할 때 마지막 배치를 사용하지 않는 것.