[Pytorch] Data augmentation algorithm - Kornia

Evo·2022년 7월 13일
0
post-thumbnail

Computer vision 분야에서 Data augmentation은 주로 아래의 알고리즘들을 주로 사용하게 되는데

  1. Torchivision
    https://pytorch.org/vision/stable/transforms.html

  2. Albumentations
    https://github.com/albumentations-team/albumentations

  3. Kornia
    https://github.com/kornia/kornia

이번 포스트에서는 kornia 알고리즘을 활용한 augmentation 코드 예제를 공유한다.

아래의 코드를 이용하면 손쉽게 augmentation이 가능하며

import kornia.augmentation as K
import torch.nn as nn
from torch import Tensor
import torch

class DataAugmentation(nn.Module):   

    def __init__(self) -> None:
        super().__init__()
        self.transforms = nn.Sequential(
            # K.RandomHorizontalFlip(p=0.5),
            K.RandomVerticalFlip(p=0.5),
            K.RandomThinPlateSpline(p=0.5),
            K.RandomAffine((0,180),p=0.5),
            K.RandomPerspective(0.5,p=0.5),            
            K.RandomGaussianNoise(mean=0,std=0.05, p=0.5),
            )
        self.cutmix = K.RandomCutMix(p=0.5)
    @torch.no_grad()  # disable gradients for effiency
    def forward(self, x: Tensor) -> Tensor:
        x_out = self.transforms(x)  # BxCxHxW
        x_out, _ = self.cutmix(x_out)
        return x_out

위와 같이 정의하고 아래 학습 과정에 데이터 증강 과정을 추가해 주면 된다.

#Training Loop 
augmentation = DataAugmentation() # 함수 생성

for data in train_loader: 
	inputs, labels = data
    inputs = augmentation(inputs) # 함수 사용
    ...
    

그 외에도 kornia에서는 다양한 augmentation 알고리즘 뿐만아니라 computer vision 어플리케이션들이 제공되기 때문에 kornia만 잘 사용하더라도 왠만한 computer vision 시스템들은 구현이 가능한것으로 보인다.

출처:https://kornia.readthedocs.io/en/latest/

profile
딥러닝 활용 연구

0개의 댓글