# 보유한 데이터셋을 여러 가지 방법으로 증강(augment)하여 실질적인 학습 데이터셋의 규모를 키우는 방법입니다.
import random
import numpy as np
import os
import cv2
import glob
from PIL import Image
import PIL.ImageOps
# Offline augmentation script: repeatedly pick a random source image from
# ``file_path``, apply one randomly chosen augmentation, and save the result
# as a PNG under ./custom_data/.

# Number of augmented images to generate.
num_augmented_images = 50

# Directory holding the originals; every directory entry is treated as an image.
file_path = "./data/"
file_names = os.listdir(file_path)
print(file_names)

total_origin_image_num = len(file_names)
print("total image number >>", total_origin_image_num)
if total_origin_image_num == 0:
    # Fail early with a clear message instead of the opaque "empty range"
    # error random.randint would raise inside the loop.
    raise ValueError("no source images found in " + file_path)

# Create the output directory once, up front, rather than on every iteration.
os.makedirs("./custom_data", exist_ok=True)


def _swap_red_blue(image):
    """Return a new image with the red and blue channels of *image* swapped."""
    # Normalise to 3-channel RGB first so the channel swap also works for
    # grayscale / palette inputs, whose arrays have no channel axis.
    np_image = np.array(image.convert("RGB"))
    return Image.fromarray(cv2.cvtColor(np_image, cv2.COLOR_BGR2RGB))


# Augmentation id -> (output file-name prefix, callable producing the new image).
# The callables are evaluated lazily, so the rotation angle is drawn only when
# the rotation augmentation is actually selected.
_AUGMENTATIONS = {
    # Horizontal flip.
    1: ("inverted_", lambda image: image.transpose(Image.FLIP_LEFT_RIGHT)),
    # Rotation by a random integer angle in [-20, 20) degrees.
    2: ("rotated_", lambda image: image.rotate(random.randrange(-20, 20))),
    # Resize to 224x224.
    3: ("resized_", lambda image: image.resize(size=(224, 224))),
    # Vertical flip.
    4: ("inverted_top_bottom_", lambda image: image.transpose(Image.FLIP_TOP_BOTTOM)),
    # Grayscale conversion.
    5: ("gray_", lambda image: image.convert('L')),
    # Colour change (red/blue channel swap).
    6: ("colorchanged_", _swap_red_blue),
}

# Running 1-based index used in the output file names.
augment_cnt = 1
# Iterate exactly ``num_augmented_images`` times (range(1, n) produced n - 1).
for _ in range(num_augmented_images):
    change_picture_index = random.randint(0, total_origin_image_num - 1)
    file_name = file_names[change_picture_index]
    origin_image_path = file_path + file_name
    print(origin_image_path)

    random_augment = random.randrange(1, 7)
    prefix, augment = _AUGMENTATIONS[random_augment]

    # The context manager closes the source file once the result is saved.
    with Image.open(origin_image_path) as image:
        augmented_image = augment(image)
        augmented_image.save("./custom_data/" + prefix + str(augment_cnt) + ".png")

    augment_cnt += 1
import time
import torch
import torchvision
import cv2
import numpy as np
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
from matplotlib import pyplot as plt
import albumentations
from albumentations.pytorch import ToTensorV2
import os
# dataset
class TorchvisionDataset(Dataset):
    """Dataset over a sequence of image paths that also times the transform.

    Each item is a ``(image, seconds)`` pair, where ``seconds`` is the
    wall-clock time spent between opening the file and finishing the
    optional transform.
    """

    def __init__(self, file_path, transform=None):
        """Store the indexable collection of image paths and the transform.

        Args:
            file_path: Indexable, sized collection of image file paths.
            transform: Optional callable applied to the opened PIL image.
        """
        self.transform = transform
        self.file_path = file_path

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.file_path)

    def __getitem__(self, index):
        """Open the image at *index* and return ``(image, elapsed_seconds)``.

        The timer starts after ``Image.open`` returns, so it covers only the
        transform call (or nothing measurable when no transform is set).
        """
        picture = Image.open(self.file_path[index])

        began = time.time()
        transform = self.transform
        picture = transform(picture) if transform else picture
        elapsed = time.time() - began

        return picture, elapsed
# Augmentation pipeline: resize to 256x256, take a random 224x224 crop,
# randomly flip horizontally, then convert the image to a tensor.
torchvision_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomCrop((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

# List the source directory once; the original re-listed it for every element,
# which is quadratic work and can disagree with itself if the directory changes.
dolphin_file_names = os.listdir('./dolphin/')
torchvision_dataset = TorchvisionDataset(
    file_path=[f"./dolphin/{name}" for name in dolphin_file_names],
    transform=torchvision_transform,
)

os.makedirs('./dolphin_augmented', exist_ok=True)

# Build the tensor -> PIL converter once instead of on every iteration.
to_pil_image = transforms.ToPILImage()
num_samples = len(torchvision_dataset)

# Accumulated transform time in seconds, as reported by the dataset.
total_time = 0
for i in range(num_samples):
    sample, transform_time = torchvision_dataset[i]
    img = to_pil_image(sample)
    img.save(f'./dolphin_augmented/aug_{i}.png', 'png')
    total_time += transform_time

# Seconds -> milliseconds per sample. The previous ``total_time*10`` was only
# correct for exactly 100 samples; guard against an empty directory as well.
ms_per_sample = total_time / num_samples * 1000 if num_samples else 0
print("torchvision time / sample : {} ms ".format(ms_per_sample))