딥러닝 - PyTorch: 데이터 다루기

dumbbelldore·2025년 1월 15일
0

zero-base 33기

목록 보기
77/97

1. Datasets

  • torch.utils.data.Dataset 클래스를 상속받아 사용자 정의 데이터셋을 만들거나, PyTorch에서 제공하는 기본 데이터셋을 사용할 수 있음

1-1. 기본 데이터셋 사용 예제

from torchvision import datasets

# MNIST 데이터셋
train_dataset = datasets.MNIST(
    root="./data", 
    train=True, 
    download=True, 
    transform=None  # 데이터 전처리 필요 시 정의
)

print(len(train_dataset))  # 데이터셋의 크기
print(train_dataset[0])    # 첫 번째 샘플 반환 (이미지, 라벨)

1-2. 사용자 정의 데이터셋 사용 예제

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

# Dataset 클래스 상속하여 정의
class CustomDataset(Dataset):
    def __init__(self, data, labels, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform
	
    # 필수 정의: 데이터셋 전체 샘플 수 반환
    def __len__(self):
        # 데이터셋의 전체 샘플 수 반환
        return len(self.data)
	
    # 필수 정의: idx에 해당하는 데이터와 라벨 반환
    def __getitem__(self, idx):
        sample = self.data[idx]
        label = self.labels[idx]

        if self.transform:
            sample = self.transform(sample)
        
        return sample, label

# 가상 이미지 생성 (1000개의 3x32x32 이미지)
num_samples = 1000
data = np.random.rand(num_samples, 3, 32, 32)
labels = np.random.randint(0, 10, num_samples)

# 사용자 정의 데이터셋 생성
custom_dataset = CustomDataset(data, labels)

2. Transforms

  • torchvision.transforms는 데이터를 생성하거나 불러올 때 전처리를 손쉽게 적용할 수 있는 유용한 기능을 제공함

2-1. 주요 기능

  • transforms.ToTensor: 이미지를 PyTorch 텐서로 변환
  • transforms.Normalize: 이미지 데이터의 정규화
  • transforms.Compose: 여러 전처리 함수를 연속적으로 적용

2-2. 전처리 예제

from torchvision import transforms

# 데이터 전처리 내용 정의
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# 데이터 로드 시 데이터 전처리 적용
train_dataset = datasets.MNIST(
    root="./data", 
    train=True, 
    download=True, 
    transform=transform
)

3. DataLoader

  • 정의한 Dataset으로부터 데이터를 로드하고, 모델에 공급하기 위해 배치로 나눠주는 역할을 수행함

3-1. 주요 매개변수

  • dataset: 데이터를 공급받을 Dataset.
  • batch_size: 한 번에 로드할 데이터 샘플의 개수.
  • shuffle: 데이터를 무작위로 섞을지 여부.
  • num_workers: 데이터를 로드할 때 사용할 프로세스의 개수.

3-2. 데이터 로딩 예제

from torch.utils.data import DataLoader

# 사전에 정의 및 전처리한 train_dataset 로드
train_loader = DataLoader(
    dataset=train_dataset, 
    batch_size=64, 
    shuffle=True, 
)

# 첫 배치의 이미지와 라벨 shape 확인
for images, labels in train_loader:
    print(images.shape, labels.shape)
    break

*이 글은 제로베이스 데이터 취업 스쿨의 강의 자료 일부를 발췌하여 작성되었습니다.

profile
데이터 분석, 데이터 사이언스 학습 저장소

0개의 댓글