Pytorch torcivision
torchvision.datasets
: torch에서 제공하는 데이터셋들
데이터셋 가져오는 방법 (기본적인 뼈대)
- 사용하고자 하는 데이터셋에 따라 조금씩 다를 수 있으니 공식문서를 참고하자.
from torchvision import datasets
from torch.utils.data import DataLoader
from config import Config
dataset = datasets.MNIST(root='dataset_root_dir',
train=True,
download=True,
transform=transform)
dataloader = DataLoader(dataset,
batch_size=Config.batch_size,
shuffle=True,
num_workers=Config.n_workers)
가져온 데이터셋 train 돌리기
def train():
for epoch in range(Config.epochs):
for batch_idx, (inputs, labels) in enumerate(dataloader):
outputs = model(inputs)
데이터셋 종류
내장되어 있는 함수들
__getitem__
- return samples[idx]
- train/test에서 iteration을 돌면서 sample을 하나씩 가져올 때 사용된다.
__len__
- return len(dataset)
- 데이터셋에 포함된 sample의 개수를 return한다.
torchvision.datasets.ImageFolder
- 경로에 저장되어 있는 데이터셋을 불러오고 싶을 때 사용한다.
- 디렉터리 구성은 이렇게 생겨야 한다.
- 예) dog/cat classification을 수행하고자 할 때
dataset = datasets.ImageFolder(root='dataset_root',
transform=transform)