Transforms

안소희·2024년 10월 11일

PyTorch

목록 보기
1/8

데이터는 항상 머신러닝 알고리즘 학습에 바로 사용할 수 있는 형태로 제공되지 않는다. 따라서 데이터를 변형(transform)하여 학습에 적합한 형태로 만들어야 한다.

TorchVision 데이터셋의 변형

모든 TorchVision 데이터셋은 두 개의 매개변수를 가진다:

  • transform : 특징(feature)을 변경하기 위한 함수
  • target_transform : 정답 (label)을 변경하기 위한 함수

이 매개변수들은 변형 로직을 갖는 호출 가능한 객체(callable)를 받는다

FashionMNIST 데이터셋 예시

FashionMNIST 데이터셋의 특징은 PIL Image형식이고, 정답은 정수이다. 그래서 학습을 위해서는 형태를 변환해야 한다.

  • feature : 정규화된 텐서
  • label : one-hot 인코딩된 텐서

ToTensorLambda 변형을 사용하자

import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)

ToTensor()

  • PIL Image나 NumPy ndarray를 FloatTensor로 변환
  • 이미지 픽셀의 크기(intensity) 값을 [0., 1.] 범위로 정규화

Lambda 변형

Lambda 변형은 사용자 정의 람다 함수를 적용한다. 위 예시에서는 정수를 원-핫 인코딩된 텐서로 변환하는 함수를 정의했다:

  1. 크기가 10인 영(zero) 텐서를 생성 (클래스의 수가 10개이므로)
  2. scatter_ 함수를 호출하여 주어진 정답 y에 해당하는 인덱스에 value=1을 할당
target_transform = Lambda(lambda y: torch.zeros(
    10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))

이렇게 변형을 적용하면 FashionMNIST 데이터셋을 머신러닝 모델 학습에 적합한 형태로 준비할 수 있다.

profile
인공지능.관심 있습니다.

0개의 댓글