[U Stage] Week 2: Custom Dataset Practice (Code Transcription)

윰진 · September 27, 2022

Let's transcribe the code for handling a custom Dataset.
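
Before transcribing the NotMNIST class, it helps to recall the minimal interface a map-style Dataset must provide: __len__ and __getitem__. The toy class below only illustrates that interface; it is not part of the course code.

from torch.utils.data import Dataset

class ToyDataset(Dataset):
    """Minimal map-style dataset: an index goes in, a (sample, label) pair comes out."""

    def __init__(self, samples, labels):
        self.samples = samples
        self.labels = labels

    def __len__(self):
        # total number of samples the dataset exposes
        return len(self.samples)

    def __getitem__(self, index):
        # fetch one sample/label pair by its index
        return self.samples[index], self.labels[index]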

Custom Dataset Practice

Library imports

from torchvision.datasets import VisionDataset
from typing import Any, Callable, Dict, List, Optional, Tuple
import os

from tqdm import tqdm
import sys
from pathlib import Path
import requests

from skimage import io, transform
import matplotlib.pyplot as plt

Writing a class to handle the NotMNIST data

import tarfile

class NotMNIST(VisionDataset):
    resource_url = 'http://yaroslavvb.com/upload/notMNIST/notMNIST_large.tar.gz'
    
    def __init__(
        self,
        root: str,
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super(NotMNIST, self).__init__(root, transform=transform,
                                       target_transform=target_transform)

        # Download the data in either of these two cases:
        # - the user passed download=True
        # - the data does not exist yet
        if not self._check_exists() or download:
            self.download()
        
        self.data, self.targets = self._load_data()
        
    def __len__(self):
        return len(self.data)
        
    def __getitem__(self, index):
        image_name = self.data[index]
        image = io.imread(image_name)
        label = self.targets[index]
        if self.transform:
            image = self.transform(image)
        return image, label
        
    def _load_data(self):
        filepath = self.image_folder
        data = []
        targets = []
        
        for target in os.listdir(filepath):
            filenames = [os.path.abspath(os.path.join(filepath, target, x))
                         for x in os.listdir(os.path.join(filepath, target))]
                
            targets.extend([target] * len(filenames))
            data.extend(filenames)
            
        return data, targets
        
    @property
    def raw_folder(self) -> str:
        return os.path.join(self.root, self.__class__.__name__,'raw')
        
    @property
    def image_folder(self) -> str:
        return os.path.join(self.root, 'notMNIST_large')
        
        
    def download(self) -> None:
        os.makedirs(self.raw_folder, exist_ok=True)
        fname = self.resource_url.split("/")[-1]
        chunk_size = 1024
        
        user_agent = 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.95 Safari/537.36'
        
        filesize = int(requests.head(
            self.resource_url,
            headers={
                "User-Agent" : user_agent
                }).headers["Content-Length"])
                
        with requests.get(self.resource_url, stream=True,
                          headers={"User-Agent": user_agent}) as r, \
             open(os.path.join(self.raw_folder, fname), "wb") as f, \
             tqdm(unit="B",           # unit string to be displayed
                  unit_scale=True,    # let tqdm pick the scale (kilo, mega, ...)
                  unit_divisor=1024,  # divisor used when unit_scale is True
                  total=filesize,     # total number of bytes expected
                  file=sys.stdout,    # show the bar on the console (default: stderr)
                  desc=fname          # prefix shown on the progress bar
                  ) as progress:
            for chunk in r.iter_content(chunk_size=chunk_size):
                # download the file one chunk at a time
                datasize = f.write(chunk)
                # update the progress bar after every chunk
                progress.update(datasize)

        # extract the downloaded archive so that image_folder actually exists
        self._extract_file(os.path.join(self.raw_folder, fname), target_path=self.root)
             
    def _extract_file(self, fname, target_path) -> None:
    	if fname.endswith("tar.gz"):
            tag = "r:gz"
        elif fname.endswith("tar"):
            tag = "r:"
        tar = tarfile.open(fname, tag)
        tar.extractall(path=target_path)
        tar.close()
        
    def _check_exists(self) -> bool:
        return os.path.exists(self.raw_folder)
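
For reference, _load_data derives the label of every image from the name of the folder it sits in, so after download and extraction the layout under root is expected to be data/notMNIST_large/<letter>/*.png. Once the data is in place, a quick check like the one below (a sketch, not part of the original notebook) should list the ten letter classes:

import os

# One sub-folder per class under image_folder; _load_data uses the folder
# name as the label for every image inside it.
print(sorted(os.listdir("data/notMNIST_large")))
# expected: ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J']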

Creating a dataset object

dataset = NotMNIST("data",download=True)

Checking the data

fig = plt.figure()

for i in range(8):
    sample = dataset[i]
    
    ax = plt.subplot(1, 4, i + 1 ) 
    plt.tight_layout()
    ax.set_title('Sample #{}'.format(i))
    ax.axis('off')
    
    plt.imshow(sample[0])
    
    if i == 3:
        plt.show()
        break

Applying transforms

  • The data augmentation pipeline (a usage sketch follows the snippet below)
import torch
from torchvision import transforms, datasets

data_transform = transforms.Compose([
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
    ])
    
dataset = NotMNIST("data", download = False )

Inspecting the data

dataset = NotMNIST("data", download = False )

dataset_loader = torch.utils.data.DataLoader(dataset,
                                             batch_size=128, shuffle=True)
                            
train_features, train_labels = next( iter(dataset_loader) )

train_features.shape

train_labels
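
Because _load_data stores folder names as targets, train_labels comes back from the DataLoader as a list of letters rather than a tensor. To feed a classifier you would normally map them to integer class indices first; the mapping below is illustrative, not part of the original code:

# Build a letter -> index mapping from the labels the dataset collected.
classes = sorted(set(dataset.targets))
class_to_idx = {c: i for i, c in enumerate(classes)}

# Convert one batch of string labels into a tensor of class indices.
label_ids = torch.tensor([class_to_idx[c] for c in train_labels])
print(label_ids[:10])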

This post is based on the educational materials of the Connect Foundation's Naver AI Boost Camp.
