Pytorch - Transform(1)

나라마야·2023년 7월 22일

PyTorch

목록 보기
4/5

데이터 변형은 어느 모델을 학습시킬 때든 중요합니다. pytorch에서 기본적으로 제공하는 Transform에 대해 먼저 알아보겠습니다.

Pytorch 튜토리얼 - Transform

공식사이트 - 한국 사용자 튜토리얼

TorchVision의 데이터셋들은 1. 특징을 변경하기 위한 transform, 2. 정답을 변경하기 위한 target_transform, 이렇게 2개의 매개변수를 받습니다.

import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)

ToTensor()

이미지 형식을 모델이 사용할 수 있는 FloatTensor로 변환하고,
이미지 픽셀 범위를 [0, 1]로 비례하여 조정합니다.

Lambda 변형

사용자 정의 람다 함수를 적용합니다. 여기서는 정수를 원핫 인코딩을 적용했습니다.

이렇게 튜토리얼 사이트에 올라온 형식으로는 변형 방식을 사용자 정의 람다 함수를 사용해 복잡합니다. 다음으로 Docs에서 더 자세하게 알아 볼 수 있습니다.

Pytorch Docs - Transforming and augmenting images

공식 사이트

우선 코드를 보여드리겠습니다.

transforms = torch.nn.Sequential(
    transforms.CenterCrop(10),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
)
scripted_transforms = torch.jit.script(transforms)

transform.<원하는 변형> 방식으로 기존 이미지를 변형시킵니다.

pytorch에서 제공하는 변형 방식 목록은 공식 사이트에서 확인하실 수 있습니다.

profile
언제나 나 자신에게 되물어 보기. So What?

1개의 댓글

comment-user-thumbnail
2023년 7월 22일

개발자로서 성장하는 데 큰 도움이 된 글이었습니다. 감사합니다.

답글 달기