PyTorch는 데이터 처리를 위한 두 가지 클래스(torch.utils.data.Dataset 및 torch.utils.data.DataLoader)를 제공한다.
Dataset은 샘플과 정답(label)을 저장하고, DataLoader는 Dataset을 샘플에 쉽게 접근할 수 있또록 순회 가능한 객체(iterable)로 감싼다.
from torch.utils.data import Dataset
class custom_dataset(Dataset):
#1. def __init__(self, [내가 필요한 것들], transforms=None):
[생성자, 데이터 셋을 가져와서 선처리를 해준다, 필요한 변수들을 선언, 데이터를 array형태로 정리해주면 getitem에서 가져와 사용하기 수월해짐]
#2. def __len__(self):
[데이터 셋의 길이를 반환]
#3. def __getitem__(self, idx):
[데이터 셋에서 한 개의 데이터를 가져오는 함수를 정의]
from torch.utils.data import Dataset
class CIFAR10(Dataset):
# class가 기본적으로 가지고 있는 data를 넣어둔 공간
base_folder = 'cifar-10-batches-py'
url = "https://www.cs.toronto.edu/~kriz/cifar-10-python/tar.gz"
tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
train_list = [
['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
['data_batch_4', '634d18415352ddfa80567beed471001a'],
['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb']
]
test_list = [
['test_batch', '40351d587109b95175f43aff81a1287e']
]
def __init__(self, root, train=True,
transform=None, target_transform=None,
download=Fasle):
self.root = os.path.expanduser(root)
self.transform = transform
self.target_transform = target_transform
self.train = train # training set or test set
if download:
self.download()
if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.'+' You can use download=True to download it.')
if self.train:
self.train_data = []
self.train_labels = []
for fentry in self.train_list:
f = fentry[0]
file = os.path.join(self.root, self.base_folder, f)
fo = open(file, 'rb')
if sys.version_info[0] == 2:
entry = pickle.load(fo)
else:
entry = pickle.load(fo, encoding='latin1')
self.train_data.append(entry['data')
if 'labels' in entry:
self.train_labels += entry['labels']
else:
self.train_labels += entry['fine_labels']
fo.close()
self.train_data = np.concatenate(self.train_data)
self.train_data = self.train_data.reshape((50000, 3, 32, 32)) # cifar10 이미지가 3*32*32가 50000장 있음->이를 (50000, 3, 32, 32) 이렇게 array로 만들어 놓으면 getitem에서 쉽게 가져올 수 있음
self.train_data = self.train_data.transpose((0, 2, 3, 1)) # convert to HWC
else:
f = self.test_list[0][0]
file = os.path.join(self.root, self.base_folder, f)
fo = open(file, 'rb')
if sys.version_info[0] == 2:
entry = pickle.load(fo)
else:
entry = pickle.load(fo, encoding='latin1')
self.test_data = entry['data']
if 'labels' in entry"
self.test_labels = entry['labels']
else:
self.test_labels = entry['fine_labels']
fo.close()
self.test_data = self.test_data.reshape((10000, 3, 32, 32))
self.test_data = self.transpose((0, 2, 3, 1)) #convert to HWC
def __getitem__(self, index): # __init__에서 array 형태로 만들어 놓은 것에서 index를 가지고 바로 원하는 이미지에 접근이 가능해짐
if self.train:
img, target = self.train_data[index], self.train_labels[index]
else:
img, target = self.test_data[index], self.test_labels[index]
img = image.fromarray(img)
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self):
if self.train:
return len(self.train_data)
else:
return len(self.test_data)
import torch.utils.data as Dataset
class ImageFolder(Dataset):
def __init__(self. root, transform=None, target_transform=None, loader=default_loader):
classes, class_to_idx = find_classes(root)
imgs = make_dataset(root, class_to_idx)
if len(imgs) == 0:
raise(RuntimeEorror("Found 0 images in subfolders of: " + root + "\n" "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
self.root = root
self.imgs = imgs
self.classes = classes
self.class_to_idx = class_to_idx
self.transform = transform
self.target_transform = target_transform
elf.loader = loader
def __getitem__(self, index):
path, target = self.imgs[index]
img = self.loader(path)
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self):
return len(self.imgs)
from torch.utils.data import Dataloader
dataloader = Dataloader(
dataset,
batch_size = ,
shuffle = True
)