Computer vision 분야에서 Data augmentation은 주로 아래의 알고리즘들을 주로 사용하게 되는데
Torchivision
https://pytorch.org/vision/stable/transforms.html
Albumentations
https://github.com/albumentations-team/albumentations
이번 포스트에서는 kornia 알고리즘을 활용한 augmentation 코드 예제를 공유한다.
아래의 코드를 이용하면 손쉽게 augmentation이 가능하며
import kornia.augmentation as K
import torch.nn as nn
from torch import Tensor
import torch
class DataAugmentation(nn.Module):
def __init__(self) -> None:
super().__init__()
self.transforms = nn.Sequential(
# K.RandomHorizontalFlip(p=0.5),
K.RandomVerticalFlip(p=0.5),
K.RandomThinPlateSpline(p=0.5),
K.RandomAffine((0,180),p=0.5),
K.RandomPerspective(0.5,p=0.5),
K.RandomGaussianNoise(mean=0,std=0.05, p=0.5),
)
self.cutmix = K.RandomCutMix(p=0.5)
@torch.no_grad() # disable gradients for effiency
def forward(self, x: Tensor) -> Tensor:
x_out = self.transforms(x) # BxCxHxW
x_out, _ = self.cutmix(x_out)
return x_out
위와 같이 정의하고 아래 학습 과정에 데이터 증강 과정을 추가해 주면 된다.
#Training Loop
augmentation = DataAugmentation() # 함수 생성
for data in train_loader:
inputs, labels = data
inputs = augmentation(inputs) # 함수 사용
...
그 외에도 kornia에서는 다양한 augmentation 알고리즘 뿐만아니라 computer vision 어플리케이션들이 제공되기 때문에 kornia만 잘 사용하더라도 왠만한 computer vision 시스템들은 구현이 가능한것으로 보인다.