Dataset
import torch
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self, text, labels):
self.labels = labels
self.data = text
def __len__(self):
return len(self.labels)
def __getitem__(self,idx):
label = self.labels[idx]
text = self.data[idx]
sample = {"Text" : text, "Class":label}
return sample
DataLoader
Data의 Batch를 생성해주는 클래스
학습직전 데이터의 변환을 책임지며, 대부분 이 부분에서 Tensor를 변환한다.
Tensor 변환과 Batch처리가 메인 업무
parameter
sampler / batch_sampler: 데이터를 어떻게 뽑을지 인덱스를 정해주는 기법
collate_fn : zero-padding이나 Variable 데이터 등 데이터 사이즈를 맞추기 위해 사용
# 글자가 긍정인지 부정인지에 대한 dataset
text = ["Happy","Amazing","Sad","Unhappy","Glum"]
labels = ["Positive","Positive","Negative","Negative","Negative"]
MyDataset = CustomDataset(text,labels)
# iterable한 객체이기 때문에 iter와 next 사용
MyDataLoader = DataLoader(MyDataset, batch_size=2, shuffle=True)
next(iter(MyDataLoader))
MyDataLoader = DataLoader(MyDataset, batch_size=3, shuffle=True)
for dataset in MyDataLoader:
print(dataset)