[Pytorch] 이미지 분류기 구현 - 데이터셋 준비 및 전처리

윤형준·2022년 9월 4일
0
post-thumbnail

데이터셋 준비 및 전처리

  • 작성된 데이터로더 부분에서 CIFAR-10 train, validation, test set을 다운 받고, 이미지 데이터를 정규화 한 뒤, 데이터로더에 넣어 학습을 위한 준비하는 과정을 살펴보고 이해합니다.

  • 이때 torch.utils.data.DataLoader의 argument들에 대해서 링크를 참고하여 실습합니다.


데이터셋 준비

CIFAR-10 ?

  • CIFAR-10은 3x32x32 크기의 이미지로 총 10개의 클래스(‘airplane’, ‘automobile’, ‘bird’, ‘cat’, ‘deer’, ‘dog’, ‘frog’, ‘horse’, ‘ship’, ‘truck’) 로 구성된 60000장의 데이터셋

  • torchvision을 사용하여 CIFAR-10 을 load 할 수 있습니다.

  • torchvision.transform 을 사용하여 데이터를 normalize할 수 있습니다.

references

dataloader(pytorch doc)

dataloader parameter(blog)

추가적으로 컴퓨터 비전 데이터를 다룰 때 torchvision이라는 패키지를 사용하면 유용하고, 여러 데이터셋들을 다운 받을 수 있습니다.

# Req. 1-1	데이터셋 준비 및 전처리, 시각화
NUM_TRAIN = 49000

# torchvision.transforms 내에는 데이터 전처리 및 data augmentation을 위한 패키지를 제공
# 본 실습에서 [0,1] 범위의 데이터셋을 [-1, 1] 범위의 값으로 normalize하도록 transform 정의
transform = T.Compose([
                T.ToTensor(),
                T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])

# Dataset을 (train / val / test)으로 분리
# 분리된 Datasets를 DataLoader로 wrap하여, 추후 각 data들이 매 iteration마다 미니배치로 제공됨
cifar10_train = dset.CIFAR10('./cores/datasets/', train=True, download=True,
                             transform=transform)
loader_train = DataLoader(cifar10_train, batch_size=64, 
                          sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN)))

cifar10_val = dset.CIFAR10('./cores/datasets/', train=True, download=True,
                           transform=transform)
loader_val = DataLoader(cifar10_val, batch_size=64, 
                        sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN, 50000)))

cifar10_test = dset.CIFAR10('./cores/datasets/', train=False, download=True, 
                            transform=transform)
loader_test = DataLoader(cifar10_test, batch_size=64)

# CIFAR10의 10개의 class 정의
classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')

torchvision.transform.normalize

normalize는 T.Normalize((mean1, -, -), (std1, -, -)) 식으로 쓰는데
첫 번째 괄호는 mean 두 번째 괄호는 std다.
X - mean / std 형태로 정규화하므로
기존의 범위인 [0, 1]에 적용하면
(0-0.5)/0.5 = -1
(1-0.5)/0.5 = 1이 되어
[-1, 1] 범위의 값으로 normalize하도록 transform을 정의한 것이다.

데이터 시각화

  • matplotlib는 python visualization library
  • matplotlib의 imshow 함수를 데이터들을 시각화 할 수 있습니다.

references

matplotlib doc

# 이미지를 시각화하는 함수
def visualize(img):
    ################################################################################
    # TODO: 시각화를 위한 코드 작성.                                                    #
    ################################################################################
    # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
    pass
    # 1) [-1, 1] 범위로 normalize된 데이터를 [0,1] 범위로 unnormalize 
    # 2) img를 numpy값으로 변환
    # 3) plt.imshow함수로 시각화
    img = (img + 1) / 2
    # img = img / 2 + 0.5
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()
    # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
    ################################################################################
    #                                 END OF YOUR CODE                             #
    ################################################################################
    

# 트레이닝 데이터를 랜덤 샘플
dataiter = iter(loader_train)
images, labels = dataiter.next() ## image

# show images
visualize(torchvision.utils.make_grid(images))

profile
매일 조금씩 성장하는 개발자

0개의 댓글