[AI] Data Augmentation

Bora Kwon·2022년 5월 24일

Data Augmentation

보유한 데이터셋을 여러 가지 방법으로 변형(augment)하여 실질적인 학습 데이터셋의 규모를 키우는 기법

  • Flipping (이미지 뒤집기)
  • Brightness (이미지 밝기 변경)
  • Resize (이미지 크기 변경)
  • 그 외 다수...
import random
import numpy as np
import os
import cv2
import glob
from PIL import Image
import PIL.ImageOps

# How many augmented images to write into ./custom_data.
num_augmented_images = 50

# Directory containing the original images.
file_path = "./data/"

# Keep only regular files: os.listdir also lists sub-directories (for example
# a Jupyter ".ipynb_checkpoints" folder), which Image.open cannot read.
file_names = [
    name
    for name in os.listdir(file_path)
    if os.path.isfile(os.path.join(file_path, name))
]
print(file_names)

total_origin_image_num = len(file_names)
print("total image number >>", total_origin_image_num)

if total_origin_image_num == 0:
    # Fail with a clear message instead of random.randint(0, -1)'s ValueError.
    raise ValueError(f"no image files found in {file_path!r}")

# The output directory is loop-invariant, so create it once up front.
os.makedirs("./custom_data", exist_ok=True)

# range(1, N + 1) produces exactly N images numbered 1..N; augment_cnt is the
# number embedded in each output file name.
for augment_cnt in range(1, num_augmented_images + 1):
    # Pick a random source image (sampling with replacement).
    change_picture_index = random.randint(0, total_origin_image_num - 1)
    file_name = file_names[change_picture_index]
    origin_image_path = os.path.join(file_path, file_name)
    print(origin_image_path)

    # "with" releases the source file handle as soon as this image is done.
    with Image.open(origin_image_path) as image:
        # Choose one of six augmentations uniformly (randrange excludes 7).
        random_augment = random.randrange(1, 7)

        if random_augment == 1:
            # Horizontal (left-right) flip
            inverted_image = image.transpose(Image.FLIP_LEFT_RIGHT)
            inverted_image.save("./custom_data/" + "inverted_" + str(augment_cnt) + ".png")

        elif random_augment == 2:
            # Rotation by a random whole-number angle; randint includes both
            # endpoints, so the range is symmetric: -20..20.
            rotated_image = image.rotate(random.randint(-20, 20))
            rotated_image.save("./custom_data/" + "rotated_" + str(augment_cnt) + ".png")

        elif random_augment == 3:
            # Resize to a fixed 224x224 (aspect ratio is not preserved)
            resized_image = image.resize(size=(224, 224))
            resized_image.save("./custom_data/" + "resized_" + str(augment_cnt) + ".png")

        elif random_augment == 4:
            # Vertical (top-bottom) flip
            inverted_top_bottom_image = image.transpose(Image.FLIP_TOP_BOTTOM)
            inverted_top_bottom_image.save("./custom_data/" + "inverted_top_bottom_" + str(augment_cnt) + ".png")

        elif random_augment == 5:
            # Grayscale
            gray_image = image.convert('L')
            gray_image.save("./custom_data/" + "gray_" + str(augment_cnt) + ".png")

        elif random_augment == 6:
            # Colour change: applying COLOR_BGR2RGB to RGB-ordered data swaps
            # the red and blue channels. Converting to RGB first guarantees a
            # 3-channel array; grayscale or palette images would otherwise
            # yield a 2-D array that cvtColor rejects.
            np_image = np.array(image.convert("RGB"))
            colorchanged_image = Image.fromarray(cv2.cvtColor(np_image, cv2.COLOR_BGR2RGB))
            colorchanged_image.save("./custom_data/" + "colorchanged_" + str(augment_cnt) + ".png")

PyTorch(torchvision)를 활용한 transform

import time

import torch
import torchvision
import cv2
import numpy as np
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
from matplotlib import pyplot as plt
import albumentations
from albumentations.pytorch import ToTensorV2
import os

# dataset
class TorchvisionDataset(Dataset):
    """Dataset over image files that also reports how long the transform took.

    Each item is an ``(image, transform_seconds)`` tuple. ``image`` is the
    output of ``transform``, or the PIL image itself when no transform is set.
    """

    def __init__(self, file_path, transform=None):
        """Store the image paths and the optional transform.

        Args:
            file_path: sequence of image file paths, indexed by ``__getitem__``.
            transform: optional callable applied to each loaded PIL image.
        """
        self.file_path = file_path
        self.transform = transform

    def __getitem__(self, index):
        """Load image ``index``, apply the transform, and time only the transform.

        Returns:
            ``(image, total_time)`` where ``total_time`` is in seconds.
        """
        path = self.file_path[index]
        # Decode the pixels inside "with" so the file handle is released right
        # away rather than when the image object is garbage-collected. copy()
        # forces the decode here, so it is not counted as transform time below.
        with Image.open(path) as opened:
            image = opened.copy()

        # perf_counter is a monotonic, high-resolution clock intended for
        # measuring short intervals; time.time() is a wall clock and may be
        # too coarse for per-sample timing.
        start_t = time.perf_counter()
        if self.transform:
            image = self.transform(image)
        total_time = time.perf_counter() - start_t

        return image, total_time

    def __len__(self):
        """Return the number of image paths."""
        return len(self.file_path)

# Augmentation pipeline: resize to 256x256, take a random 224x224 crop,
# mirror it horizontally with probability 0.5, then convert it to a tensor.
torchvision_transform = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.RandomCrop(size=(224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
    ]
)

# List the input directory once and keep only regular files, so that
# sub-directories such as ".ipynb_checkpoints" are never passed to Image.open.
dolphin_dir = './dolphin/'
dolphin_files = [
    os.path.join(dolphin_dir, name)
    for name in os.listdir(dolphin_dir)
    if os.path.isfile(os.path.join(dolphin_dir, name))
]

torchvision_dataset = TorchvisionDataset(
    file_path=dolphin_files,
    transform=torchvision_transform,
)

os.makedirs('./dolphin_augmented', exist_ok=True)

total_time = 0
num_samples = len(torchvision_dataset)

# Converts each transformed tensor back into a PIL image so it can be saved;
# the converter is loop-invariant, so build it once.
to_pil = transforms.ToPILImage()

for i in range(num_samples):
    sample, transform_time = torchvision_dataset[i]
    img = to_pil(sample)
    img.save(f'./dolphin_augmented/aug_{i}.png', 'png')

    total_time += transform_time

# total_time is the sum of per-sample transform times in seconds, so the mean
# in milliseconds is total_time / num_samples * 1000 (guarding an empty dir).
avg_ms = total_time / num_samples * 1000 if num_samples else 0.0
print("torchvision time / sample : {} ms ".format(avg_ms))
profile
Software Developer

0개의 댓글