6. Custom Dataset과 DataLoader

ingsol·2023년 1월 7일
0

PyTorch

목록 보기
6/8
post-custom-banner

📌0. Dataset과 DataLoader

PyTorch는 데이터 처리를 위한 두 가지 클래스(torch.utils.data.Dataset 및 torch.utils.data.DataLoader)를 제공한다.
Dataset은 샘플과 정답(label)을 저장하고, DataLoader는 Dataset을 샘플에 쉽게 접근할 수 있또록 순회 가능한 객체(iterable)로 감싼다.

1. Dataset

from torch.utils.data import Dataset

class custom_dataset(Dataset):
#1. def __init__(self, [내가 필요한 것들], transforms=None):
	[생성자, 데이터 셋을 가져와서 선처리를 해준다, 필요한 변수들을 선언, 데이터를 array형태로 정리해주면 getitem에서 가져와 사용하기 수월해짐]

#2. def __len__(self):
	[데이터 셋의 길이를 반환]

#3. def __getitem__(self, idx):
	[데이터 셋에서 한 개의 데이터를 가져오는 함수를 정의]

[example1]

  • cifar-10
  • os.path: 파일 경로를 생성 및 수정하고, 파일 정보를 쉽게 다룰 수 있게 해주는 모듈
  • os.path.expanduser(path): 입력 받은 경로안의 "~"를 현재 사용자 디렉토리의 절대경로로 대체
  • 절대 경로: 최초디렉토리를 기준으로경유한 경로를 모두 기입한 전체 경로를 의미
  • 상대 경로: 절대 경로와 다르게 최초 디렉토리가 아닌 특정 경로를 기준으로 경로를 기입하는 방식. 주로 현재 작업하고 있는 폴더를 기준으로 함.
  • 출처: https://abluesnake.tistory.com/129
  • 출처: https://ok-lab.tistory.com/163
  • 출처: https://devanix.tistory.com/298
  • 📌self.train_data = self.train_data.reshape((50000, 3, 32, 32))
    :cifar10 이미지가 33232가 50000장 있음->이를 (50000, 3, 32, 32) 이렇게 array로 만들어 놓으면 getitem에서 쉽게 가져올 수 있음
from torch.utils.data import Dataset

class CIFAR10(Dataset):
	
    # class가 기본적으로 가지고 있는 data를 넣어둔 공간
	base_folder = 'cifar-10-batches-py'
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python/tar.gz"
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
    train_list = [
    	['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb']
    ]
    
    test_list = [
    	['test_batch', '40351d587109b95175f43aff81a1287e']
    ]
    
    def __init__(self, root, train=True,
    			transform=None, target_transform=None,
                download=Fasle):
    	self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train # training set or test set
        if download:
        	self.download()
        if not self._check_integrity():
        	raise RuntimeError('Dataset not found or corrupted.'+' You can use download=True to download it.')
            
       	if self.train:
        	self.train_data = []
            self.train_labels = []
            for fentry in self.train_list:
            	f = fentry[0]
                file = os.path.join(self.root, self.base_folder, f)
                fo = open(file, 'rb')
                if sys.version_info[0] == 2:
                	entry = pickle.load(fo)
                else:
                	entry = pickle.load(fo, encoding='latin1')
                self.train_data.append(entry['data')
                if 'labels' in entry:
                	self.train_labels += entry['labels']
                else:
                	self.train_labels += entry['fine_labels']
                fo.close()
                
             self.train_data = np.concatenate(self.train_data)
             self.train_data = self.train_data.reshape((50000, 3, 32, 32)) # cifar10 이미지가 3*32*32가 50000장 있음->이를 (50000, 3, 32, 32) 이렇게 array로 만들어 놓으면 getitem에서 쉽게 가져올 수 있음
             self.train_data = self.train_data.transpose((0, 2, 3, 1)) # convert to HWC
       else:
       		f = self.test_list[0][0]
            file = os.path.join(self.root, self.base_folder, f)
            fo = open(file, 'rb')
            if sys.version_info[0] == 2:
            	entry = pickle.load(fo)
            else:
            	entry = pickle.load(fo, encoding='latin1')
            self.test_data = entry['data']
            if 'labels' in entry"
            	self.test_labels = entry['labels']
            else:
            	self.test_labels = entry['fine_labels']
            fo.close()
            self.test_data = self.test_data.reshape((10000, 3, 32, 32))
            self.test_data = self.transpose((0, 2, 3, 1)) #convert to HWC
            
	def __getitem__(self, index): # __init__에서 array 형태로 만들어 놓은 것에서 index를 가지고 바로 원하는 이미지에 접근이 가능해짐
    	if self.train:
        	img, target = self.train_data[index], self.train_labels[index]
        else:
        	img, target = self.test_data[index], self.test_labels[index]
        img = image.fromarray(img)
        
        if self.transform is not None:
        	img = self.transform(img)
        
        if self.target_transform is not None:
        	target = self.target_transform(target)
        
        return img, target
    
    def __len__(self):
    	if self.train:
        	return len(self.train_data)
        else:
        	return len(self.test_data)

[example2]

  • folder.py
  • torchvision.dataset.ImageFolder에 해당하는 내용
import torch.utils.data as Dataset

class ImageFolder(Dataset):
	def __init__(self. root, transform=None, target_transform=None, loader=default_loader):
    	classes, class_to_idx = find_classes(root)
        imgs = make_dataset(root, class_to_idx)
        if len(imgs) == 0:
        	raise(RuntimeEorror("Found 0 images in subfolders of: " + root + "\n" "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
        
        self.root = root
        self.imgs = imgs
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.target_transform = target_transform
        elf.loader = loader
        
    def __getitem__(self, index):
    	path, target = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
        	img = self.transform(img)
        if self.target_transform is not None:
        	target = self.target_transform(target)
        
        return img, target
     
   def __len__(self):
   		return len(self.imgs)
      

2. DataLoader

from torch.utils.data import Dataloader

dataloader = Dataloader(
			dataset,
            batch_size = ,
            shuffle = True
            )
post-custom-banner

0개의 댓글