4. 파이토치(PyTorch) 튜토리얼 - 변형(Transform)

Yeonghyeon·2022년 7월 29일
0
post-custom-banner

본 포스팅은 파이토치(PYTORCH) 한국어 튜토리얼을 참고하여 공부하고 정리한 글임을 밝힙니다.


변형(Transform)

데이터가 항상 머신러닝 알고리즘 학습에 필요한 최종 처리가 된 형태로 제공되지 않음
변형(Transform)을 통해 데이터를 조작하고 학습에 적합하게 만듦

모든 TorchVision 데이터셋들은 변형 로직을 갖는, 호출 가능한 객체를 받는 매개변수 두개를 가짐

  • feature를 변형하기 위한 transform
  • 정답을 변형하기위한 target_transform

FashionMNIST의 feature: PIL Image 형식
FashionMNIST의 label: 정수

이들을 학습하려면 정규화된 텐서 형태의 feature와 원-핫으로 encode된 텐서 형태의 정답이 필요!
➡️ 변형(transformation)을 위해 ToTensorLambda 사용

import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)

ToTensor()

PIL Image나 NumPy ndarray를 FloatTensor로 변환하고, 이미지의 픽셀 크기 값을 [0., 1.] 범위로 비례하여 조정(scale)함

Lambda 변형(Transform)

사용자 정의 람다(lambda) 함수를 적용하여 정수를 원-핫으로 인코드된 텐서로 바꾸는 함수를 정의
먼저 데이터셋의 정답 개수인 10짜리 zero tensor를 만들고, scatter_를 호출하여 주어진 정답 y에 해당하는 인덱스에 value=을 할당

target_transform = Lambda(lambda y: torch.zeros(
    10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))
post-custom-banner

0개의 댓글