파이토치에서는 데이터를 좀 더 쉽게 다룰 수 있도록 유용한 도구로서 데이터셋(Dataset)과 데이터로더(DataLoader)를 제공한다.
이를 사용하면 미니 배치 학습, 데이터 셔플(shuffle), 병렬 처리까지 간단히 수행할 수 있다. 기본적인 사용 방법은 Dataset을 정의하고 -> 이를 DataLoader에 전달하는 것
from torch.utils.data import TensorDataset # 텐서데이터셋
from torch.utils.data import DataLoader # 데이터로더
TensorDataset은 기본적으로 텐서를 입력으로 받는다.
x_train = torch.FloatTensor([[73, 80, 75],
[93, 88, 93],
[89, 91, 90],
[96, 98, 100],
[73, 66, 70]])
y_train = torch.FloatTensor([[152], [185], [180], [196], [142]])
dataset = TensorDataset(x_train, y_train)
데이터로더는 기본적으로 2개의 인자를 입력받는다 (데이터셋, 미니 배치의 크기)
*이때 미니 배치의 크기는 통상적으로 2의 배수를 사용
+) 추가적으로 많이 사용되는 인자로 shuffle이 있다.
*shuffle=True를 선택하면 Epoch마다 데이터셋을 섞어서 데이터가 학습되는 순서를 바꾼다.
이를 코드로 표현하면 다음과 같다.
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
# 이후 코드에 dataloader는 다음과 같이 활용
for epoch in range(nb_epochs + 1):
for batch_idx, samples in enumerate(dataloader):
# 여기서 batch_idx 랑 samples는 임의로 지정한 변수명
# 다시 말해, enumerate 규칙에 의해 반환되는 것
enumerate 함수의 동작 원리는 다음과 같다: