Data loader

Sukhun-Net·2024년 6월 21일

파이토치에서는 데이터를 좀 더 쉽게 다룰 수 있도록 유용한 도구로서 데이터셋(Dataset)과 데이터로더(DataLoader)를 제공한다.

이를 사용하면 미니 배치 학습, 데이터 셔플(shuffle), 병렬 처리까지 간단히 수행할 수 있다. 기본적인 사용 방법은 Dataset을 정의하고 -> 이를 DataLoader에 전달하는 것


from torch.utils.data import TensorDataset # 텐서데이터셋
from torch.utils.data import DataLoader # 데이터로더

TensorDataset은 기본적으로 텐서를 입력으로 받는다.

  1. 먼저 텐서 형태로 데이터를 정의


x_train  =  torch.FloatTensor([[73,  80,  75], 
                               [93,  88,  93], 
                               [89,  91,  90], 
                               [96,  98,  100],   
                               [73,  66,  70]])  

y_train  =  torch.FloatTensor([[152],  [185],  [180],  [196],  [142]])
  1. 정의한 데이터를 TensorDataset의 입력으로 사용하고 dataset으로 저장
dataset = TensorDataset(x_train, y_train)
  1. 위와 같이, 파이토치의 데이터셋을 만들었다면 데이터로더 사용 가능

DataLoader

데이터로더는 기본적으로 2개의 인자를 입력받는다 (데이터셋, 미니 배치의 크기)
*이때 미니 배치의 크기는 통상적으로 2의 배수를 사용

+) 추가적으로 많이 사용되는 인자로 shuffle이 있다.
*shuffle=True를 선택하면 Epoch마다 데이터셋을 섞어서 데이터가 학습되는 순서를 바꾼다.

이를 코드로 표현하면 다음과 같다.

dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# 이후 코드에 dataloader는 다음과 같이 활용 

for epoch in range(nb_epochs + 1):
  for batch_idx, samples in enumerate(dataloader):
  
# 여기서 batch_idx 랑 samples는 임의로 지정한 변수명 
# 다시 말해, enumerate 규칙에 의해 반환되는 것 
    

enumerate 함수의 동작 원리는 다음과 같다:

  • 첫 번째 반환 값은 반복의 인덱스이다.
  • 두 번째 반환 값은 반복 가능한 객체(dataloader)에서 가져온 실제 데이터이다.
profile
Data Scientist (Computer Vision, Multimodal)

0개의 댓글