[딥러닝] 파이토치로 딥러닝 구현 (라이브러리, 데이터 전처리, 이미지 시각화, Dataset,DataLoader)(1)

김영민·2022년 9월 20일
0

DeepLearning

목록 보기
19/33
post-custom-banner

'딥러닝 파이토치 교과서'의 코드를 따라서 짜다 보면 이해가 가긴 하는데 완벽하게 내 것이 안 되는 것 같아서 이렇게 정리해봅니다... 화이팅

라이브러리 선언

import torch
import torchvision
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import transforms
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
import os
import cv2
from PIL import Image
from tqdm import tqdm_notebook as tqdm
import random
from matplotlib import pyplot as plt
  • 기본적으로 파이토치를 사용하기 위한 : torch
  • 파이토치에서도 vision 분야를 다루기 위해서 : torchvision
  • DataLoader와 Dataset을 만들기 위해 : from torch.utils.data import DataLoader, Dataset
  • 최적화 사용을 위해 : optim
  • 파일 위치 등을 다루기 위해 : os
  • opencv 사용을 위해 : cv2
  • 이미지를 다루기 위해 : from PIL import Image
  • 시각화 : matplotlib.pyplot

1. 클래스로 이미지 전처리

  • 이미지 resize, 텐서변환, 정규화 등등

코드

class ImageTransform():
  def __init__(self,resize,mean,std): #resize할 크기, 평균, 표준편차
    self.data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(resize,scale=(0.5,1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean,std=std)
          ]),

        'val': transforms.Compose([
            transforms.Resize(256),
            transforms.CentorCrop(resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean,std=std)
          ])
      }
  def __call__(self,img,phase):
    self.data_transforms[phase](img)

        

torchvision.transforms

  • 이미지의 여러 변형 (Resize, Crop, Flip 등) 수행
    • RandomResizedCrop : scale 범위 내에서 랜덤하게 자르고, 입력된 resize만큼 크기 조정
    • RandomHorizontalFlip : p(확률) 의 비율만큼 랜덤하게 좌우 전환 (default : p=0.5)
    • ToTensor : PIL Image나 numpy.ndarray 를 tensor로 변환해줌
    • Normalize : 각 채널마다 정규화해주는 것

2. 이미지 불러오기 및 데이터 나누기

코드

cat_directory = r'/content/drive/MyDrive/dogs-vs-cats/Cat'
dog_directory = r'/content/drive/MyDrive/dogs-vs-cats/Dog'

cat_images_filepaths = sorted([os.path.join(cat_directory, f) for f in os.listdir(cat_directory)])
dog_images_filepaths = sorted([os.path.join(dog_directory, f) for f in os.listdir(dog_directory)])

images_filepaths = [*cat_images_filepaths, *dog_images_filepaths]

correct_images_filepaths = [i for i in images_filepaths if cv2.imread(i) is not None)

random.seed(42)
random.shuffle(correct_images_filepaths)
train_images_filepaths = correct_images_filepaths[:400]
val_images_filepaths = correct_images_filepaths[400:-10]
test_images_filepaths = correct_images_filepaths[-10:]

len(train_images_filepaths), len(val_images_filepaths),len(test_images_filepaths)

3. 이미지 출력하기

def display_image_grid(images_filepaths, predicted_labels=(), cols=5):

  rows = len(images_filepaths) // cols

  figure, ax = plt.subplots(nrows=rows, ncols=cols, figsize=(12,6))

  for i, image_filepath in enumerate(images_filepaths):
    image = cv2.imread(image_filepath)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    true_label = os.path.normpath(image_filepath).split(os.sep)[-2]
    predicted_label = predicted_label[i] if predicted_labels else true_label

    color = "green" if true_label==predicted_label else "red"

    ax.ravel()[i].imshow(image)
    ax.ravel()[i].set_title(predicted_label,color=color)
    ax.ravel()[i].set_axis_off()
  
  plt.tight_layout()
  plt.show()
  
display_image_grid(test_images_filepaths)

출력 결과

  • 이미지 출력을 grid로 만들어서 출력
    • plt.subplots: (figure, axes) 튜플을 반환 / figure -> 전체, axes -> 그래프 각각
    • cv2.imread: 이미지 파일을 입력하면 이미지 반환
    • cv2.cvtColor: opencv는 BGR 순서로 색상이 되어있으므로 RGB로 바꿔줘야한다.
    • os.path.normpath : 파일 정규화. //나 / 로 되어 있는 것들을 모두 / 로 통일
    • .split(input): input 기준으로 잘라서 list 반환
    • os.sep : '/'
    • ravel() : ax에 하나하나 접근 가능 ( 1차원으로 바꿔줌 )

4. Dataset 정의 및 사용

코드 (정의)

class DogvsCatDataset(Dataset):
  def __init__(self, file_list, transform=None, phase = 'train'):
    self.file_list = file_list
    self.transform = transform
    self.phase = phase
  
  def __len__(self):
    return len(self.file_list)

  def __getitem__(self, idx):
    img_path = self.file_list[idx]
    img = Image.open(img_path)
    img_transformed = self.transform(img,self.phase)
    label = img_path.split('/')[-1].split('.')[0]

    if label == 'dog':
      label = 1
    elif label == 'cat':
      label = 0
    return img_transformed, label
  • dataset 정의
    • __init__ : file_list, transform, phase 입력 받음.
    • __len__ : file_list의 길이 반환
    • __getitem__ : idx번째 이미지와 라벨 반환

코드 (사용)

train_dataset = DogvsCatDataset(
    file_list = train_images_filepaths,
    transform=ImageTransform(size,mean,std),
    phase = 'train'
)

val_dataset = DogvsCatDataset(
    file_list = val_images_filepaths,
    transform=ImageTransform(size,mean,std),
    phase = 'val'
)
  • 위에서 만든 ImageTransform 적용 ( 개체에 저장하면 자동으로 ImageTransforms의 call부분 return 됨)

5. DataLoader

코드

train_loader = DataLoader(train_dataset, batch_size=batch_size,shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size,shuffle=True, num_workers=2)

dataloader_dict = {'train': train_loader, 'test': val_loader}

post-custom-banner

0개의 댓글