PyTorch에서 Dataset, DataLoader 를 이용하여 데이터를 정리하는 코드를 볼 수가 있다. 원래의 데이터를 그대로 이용해도 될텐데 Dataset, DataLoader를 사용하는 이유가 무엇일까?
아래의 Reference에 잘 설명되어 있는데, 요약하면
Dataset은 용량이 큰 데이터를 이용하여 모델을 학습시킬 때 많은 양의 데이터를 한꺼번에 불러오면 ram에 무리가 가므로, 필요한 만큼만 불러서 쓰는 custom set을 만들어 효율적으로 관리할 수 있도록 한다. DataLoader은 Dataset의 인스턴스를 감싸서 batch size로 데이터를 로드하고 데이터셋을 섞는 등의 작업을 수행한다.그럼 이제 Dataset과 DataLoader를 어떻게 사용하는지 살펴보자.
from torch.utils.data import Dataset, DataLoader
# How to use Dataset
class MyDataset(Dataset): # CustomDataset class
def __init__(self, x_tensor, y_tensor): # 데이터 초기화
self.x = x_tensor
self.y = y_tensor
def __len__(self): # 데이터 개수 반환
return len(self.x)
def __getitem__(self, idx): # 특정 인덱스의 데이터 반환
return self.x[idx], self.y[idx]
train_dataset = MyDataset(x_train, y_train)
test_dataset = MyDataset(x_test, y_test)
# How to use DataLoader
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)