- 필요한 라이브러리
import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader from network import CustomNet from dataset import ExampleDataset from loss import ExampleLoss
- Custom modeling
- 모델 생성
model = CustomNet() model.train() #훈련모드
- 옵티마이저 정의
params = [param for param in model.parameters() if param.requires_grad] optimizer = optim.Example(params, lr=lr)
- 손실함수 정의
loss_fn = ExampleLoss()
- Custom Dataset & DataLoader
->학습을 위한 데이터셋 생성dataset_example = ExampleDataset()
-> 학습을 위한 데이터로더 생성
dataloader_example = DataLoader(dataset_example)
Transfer Learning & Hyper Parameter Tuning
- 모델 학습
for e in range(epochs): for X,y in dataloader_example: output = model(X) loss = loss_fn(output, y) optimizer.zero_grad() loss.backward() optimizer.step()
- 데이터 전처리하기
- 데이터 불러오는 함수 생성하기 : Custom Dataset & DataLoader
- 신경망 구성하기 : Custom Model
- 오차 함수 및 최적화 기법 선택하기
- 학습 및 추론 설정 및 실행
직접 모은 데이터를 학습시키려고 할 때 Dataset과 DataLoader를 구성해야한다. 혹은 데이터가 너무 커서 메모리에 한번에 올려서 학습을 하기 어려운 경우에도 사용한다.
- Dataset의 기본 구성 요소
map-style dataset
iterable dataset은 현재 쓰지 않는다.from torch.utils.data import Dataset class CustomDataset(Dataset): def __init__(self,): pass def __len__(self,): pass def __getitem__(self,idx): pass
- __init__ 메서드
일반적으로 해당 메서드에서는 데이터의 위치나 파일명과 같은 초기화 작업을 위해 동작합니다. 일반적으로 CSV파일이나 XML파일과 같은 데이터를 이때 불러옵니다. 이렇게 함으로서 모든 데이터를 메모리에 로드하지 않고 효율적으로 사용할 수 있습니다. 여기에 이미지를 처리할 transforms들을 Compose해서 정의해둡니다.
- __len__ 메서드
해당 메서드는 Dataset의 최대 요소 수를 반환하는데 사용됩니다. 해당 메서드를 통해서 현재 불러오는 데이터의 인덱스가 적절한 범위 안에 있는지 확인할 수 있습니다.
- __getitem__ 메서드
해당 메서드는 데이터셋의 idx번째 데이터를 반환하는데 사용됩니다. 일반적으로 원본 데이터를 가져와서 전처리하고 데이터 증강하는 부분이 모두 여기에서 진행될 겁니다.
transform도 여기에서 적용함.
- DataLoader가 필요한 이유
직접적으로 데이터셋을 for 반복문으로 데이터를 이용하는건 많은 특성들을 놓칠 수 밖에 없습니다. 특히, 우리는 다음과 같은 특성들을 놓친다고 할 수 있습니다.
데이터 배치, 데이터 섞기, multiprocessing 를 이용하여 병렬적으로 데이터 불러오기 등의 특성을 가지고 있습니다.
torch.utils.data.DataLoader 는 반복자로서 위에 나와있는 모든 특성들을 제공합니다.
위의 특성들을 활용해서 DataLoader class는 전체 데이터를 batch size로 slice해서 mini batch를 만들어줄 수 있다.
batch기반의 딥러닝 학습을 도와준다.
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None)
추가 학습할 내용
sampler
collate_fn : sample list를 batch 단위로 바꾸기 위해 필요한 기능
zero-padding, 길이가 다른 데이터의 길이를 일정하게 만들어주기 위함
전체리하거나 수집한 dataset을 model에 feeding하기 위한 방법과 원하고자 하는 결과를 얻기 위해 데이터를 어떤 식으로 나눠서 학습을 진행할 것인가에 대한 고민이 필요하다.
참고자료
사용자 정의 PYTORCH DATALOADER 작성하기
https://hulk89.github.io/pytorch/2019/09/30/pytorch_dataset/
https://subinium.github.io/pytorch-dataloader/
https://wikidocs.net/16068