torchvision.datasets이란? | Pytorch에서 제공하는 데이터셋 사용하기

Seohyun·2023년 8월 3일
0

정리

목록 보기
2/9
post-thumbnail
post-custom-banner

Pytorch torcivision

  • torchvision.datasets: torch에서 제공하는 데이터셋들

데이터셋 가져오는 방법 (기본적인 뼈대)

  • 사용하고자 하는 데이터셋에 따라 조금씩 다를 수 있으니 공식문서를 참고하자.
from torchvision import datasets
from torch.utils.data import DataLoader

# Config: hyperparameter 저장해두는 class
from config import Config

# dataset 불러오기
dataset = datasets.MNIST(root='dataset_root_dir',
						# train=True는 train set을, train=False는 test set을 불러온다.
						train=True,
                        download=True,
                        # 이미지/영상 데이터에 transform을 정의해 적용할 수 있다.
                        transform=transform)

# dataset을 batch_size 크기로 잘라준다고 생각하면 된다.
dataloader = DataLoader(dataset,
                        batch_size=Config.batch_size,
                        # 보통 train set은 shuffle을 True로, test set은 False로 설정한다. 
                        shuffle=True,
                        # 여러 데이터를 병렬로 가져오기 위한 n_workers
       					num_workers=Config.n_workers)

가져온 데이터셋 train 돌리기

def train():
	
    for epoch in range(Config.epochs):
    	for batch_idx, (inputs, labels) in enumerate(dataloader):

            outputs = model(inputs)
            # blah blah

데이터셋 종류

내장되어 있는 함수들

  • __getitem__
    • return samples[idx]
    • train/test에서 iteration을 돌면서 sample을 하나씩 가져올 때 사용된다.
  • __len__
    • return len(dataset)
    • 데이터셋에 포함된 sample의 개수를 return한다.

torchvision.datasets.ImageFolder

  • 경로에 저장되어 있는 데이터셋을 불러오고 싶을 때 사용한다.
  • 디렉터리 구성은 이렇게 생겨야 한다.
    • 예) dog/cat classification을 수행하고자 할 때
      • root
           |__ dog
           		|__ xxx.png
                |__ xxy.png
                |__xxz.png
           |__ cat
           		|__ cat1.png
                |__ c.png
                
dataset = datasets.ImageFolder(root='dataset_root',
							   transform=transform)
profile
Hail hamster
post-custom-banner

0개의 댓글