CIFAR10에 Albumentation 적용

잠만보 AI·2022년 10월 28일
0

Albumentation에 augmentation 방법이 많다고 하여 cifar10에 적용하고 코드를 공유한다.

# Cifar10 Custom Dataset code
# source: https://albumentations.ai/docs/autoalbument/examples/cifar10/

class Cifar10SearchDataset(torchvision.datasets.CIFAR10):
    def __init__(self, root="~/data/cifar10", train=True, download=True, transform=None):
        super().__init__(root=root, train=train, download=download, transform=transform)

    def __getitem__(self, index):
        image, label = self.data[index], self.targets[index]

        if self.transform is not None:
            transformed = self.transform(image=image)
            image = transformed["image"]

        return image, label
# transformation code

transform_train = A.Compose([
    A.Normalize(),
    A.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.4),
    A.RandomFog(fog_coef_lower=0.7, fog_coef_upper=0.8, alpha_coef=0.1, p=1),
    A.RandomSnow(brightness_coeff=2.5, snow_point_lower=0.3, snow_point_upper=0.5, p=1),
    A.AdvancedBlur(),
    A.GaussNoise(),
    # A.ISONoise(),
    A.HueSaturationValue(),
    ToTensorV2()
])

transformation용 코드를 따로 만들고 dataloader 만들 시 적용해주면 된다. Albumentation 방법은 아래 링크에서 찾을 수 있다.
https://albumentations.ai/docs/getting_started/transforms_and_targets/
특이하게 isonoise는 cifar10에 적용이 안된다. uint8 image에만 된다고 error가 뜨길래 없이 사용하였다. (정확한 이유는 아직 모른다).

trainset = Cifar10SearchDataset(root='/home/data/d/cifar10',
transform=transform_train)

trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2)
    

인제 trainloader을 train때 써주면 된다!

P.S.
augmentation이 많다고해서 무조건 성능이 좋아지는 것은 아니다 (MSG 미연 마냥 마법의 가루가 아니라는 뜻인 것).
모델을 조금 더 robust하게 만들 수 있지만 적당한 선에서 하는게 좋다.

출처:
https://albumentations.ai/docs/autoalbument/examples/cifar10/

profile
생명공학을 전공했지만 AI에 관심있는 사람

0개의 댓글