Pytorch - Transform(2)

나라마야·2023년 7월 23일

PyTorch

목록 보기
5/5

저번 Pytorch - Transform에서는 pytorch에서 제공하는 transform을 사용했습니다.

이번에는 "Data Augmentation"과 pytorch에서도 사용가능한 "Albumentations"에 대해서 설명하고자 합니다.

Data Augmentation

데이터 증강이 필요한 이유는 유트브 동영상블로그에 더 자세하게 적혀있습니다.

데이터 증강이 가장 중요한 이유는 모델의 과적합을 방지하고 모델을 일반화시키기 위해서 입니다.

데이터 증강을 할 때 기존 이미지에 여러 변형을 주는데, 이때 이미지의 명도, 채도, 각도 등을 바꿀 수 있습니다.

변형을 할 때 중요한 점은 이미지 고유의 의미는 변함이 없어야 한다는 것입니다. 햄스터 원본 사진을 변형했을 때 변형된 사진에서 햄스터의 의미가 사라지지 말아야 합니다.

Albumentations

공식 사이트와 예제가 있는 Github 주소입니다.

코드는 공식 사이트에 있는 예제 중 "Mask augmentation for segmentation"을 참조했습니다.

import albumentations as A

transform = A.Compose([
    A.RandomCrop(width=256, height=256),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
])

가장 기본적인 틀이 되는 transform 예제 입니다.

어떻게 이미지를 변형할 지 정의했다면 이미지 데이터셋을 가져올 때 변형을 가합니다.

class CustomDataset(Dataset):
	def __init__(self, img_path, transform):
    	...
        
	def __len__(self):
    	...
	
    def __getitem__(self, ...):
    	...
        if self.transform:
            transformed = self.transform(image=image, masks=masks)
            transformed_image = transformed['image']
            transformed_masks = transformed['masks']

        return transformed_image, transformed_masks

...

dataset = CustomDataset(img_path, transform)

위 코드에서 생략은 "..."로 표시했습니다.

변형 방식에는 여러가지가 있으며 도메인이 무엇인지, 어떤 목적을 가지고 있는지 등에 맞는 변형 방법은 달라질 수 있습니다. 여러가지 변형의 종류와 사용 방법은 위에 있는 Github에서 확인할 수 있습니다.

profile
언제나 나 자신에게 되물어 보기. So What?

0개의 댓글