[Boostcamp 2주차] PyTorch 구조 학습하기

yoonene·2022년 1월 28일
0

Boostcamp AI Tech

목록 보기
8/27

Dataset & DataLoader

📌 핵심 정리

  • 모델에 데이터를 Feeding 하는 법
  • Dataset 클래스는 기본적으로 __init__(), __len__(), __getitem__() 으로 구성
  • DataLoader는 데이터를 Tensor로 변환 + Batch 처리가 주 업무

Dataset 클래스

ex)
import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
	# 초기 데이터 생성 방법 지정
    def __init__(self, data, labels):
	self.data = data
        self.data = text
    # 데이터 길이
    def __len__(self):
        return len(self.labels)
    
    # index 값이 주어졌을 때 반환되는 데이터
    def __getitem__(self, index):
        data = self.data[index]
        label = self.labels[index]
        return data, label
  • 기본 구성
    • __init__()
    • __len__()
    • __getitem__()
  • 데이터 입력 형태를 정의하는 클래스
  • 하나의 데이터에 어떻게 적용하는가 같은 데이터 입력 방식의 표준화
  • Image, Text 등 데이터 형태에 따라 다른 정의
  • 데이터 처리를 무조건 데이터를 생성하는 dataset 클래스에 처리할 필요 없음.
    CPU에서 Tensor 변환과 GPU 학습이 병렬로 처리될 수 있기 때문.
    예를 들어 transfrom 같은 함수를 포함해도 되지만 나중에 모델 넣을 때 해줘도 됨.
  • dataset 에 대한 표준화된 처리방법이 처리되어야 다른 사람이 잘 이해할 수 있음.
  • HuggingFace (NLP) 등 표준화된 라이브러리를 사용.

DataLoader 클래스

ex)
text = ['Mad', 'Joy', 'Glum', 'Happy']
labels = ['Negative', 'Positive', 'Negative', 'Positive']
MyDataset = CustomDataset(text, labels)

MyDataLoader = DataLoader(MyDataset, batch_size=3, shuffle=True) # iterable generator
# next(iter(MyDataLoader)) # 다음 이터 찍어보기
for dataset in MyDataLoader:
   print(dataset)
# {'Text': ['Mad', 'Happy', 'Glum'], 'Class': ['Negative', 'Positive', 'Negative']}
# {'Text': ['Joy'], 'Class' : ['Positive']}

DataLoader의 parameters

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
                      batch_sampler=None, num_workers=0, collate_fn=None,
                      pin_memory=False, drop_last=False, timeout=0,
                      worker_init_fn=None, *, prefetch_factor=2,
                      persistent_workers=False)

  • batch_size
    : 배치의 크기
    예를 들어, 데이터가 100개 있고 batch가 20이면 5번의 iteration이 지나고 모든 데이터를 본다.

  • shuffle
    : 데이터를 섞느냐

  • sampler
    : 데이터를 어떻게 뽑을지 index를 조정하는 법

  • num_workers
    : 데이터 로딩에 사용하는 subprocess 개수.
    많다고 좋은 게 X. -> CPU와 GPU 사이에 너무 많은 교류는 병목을 발생시킴.

  • collate_fn
    : variable length를 처리하기 위해 많이 씀. 길이가 다른 데이터를 padding 할 때 여기서 정의
    [[Data, Label], [Data, Label]] --> [[Data, Data], [Label, Label]]

  • drop_last
    : 데이터가 batch로 나누어 떨어지지 않고 나머지가 생길 때 마지막 배치의 길이가 달라 loss를 구하기 번거롭거나, batch size에 의존도가 높은 함수를 사용할 때 마지막 배치를 사용하지 않는 것.

profile
NLP Researcher / Information Retrieval / Search

0개의 댓글