# Transcribing this code by hand as practice for working with a custom Dataset.
from torchvision.datasets import VisionDataset
from typing import Any, Callable, Dict, List, Optional, Tuple
import os
from tqdm import tqdm
import os
import sys
from pathlib import Path
import requests
from skimage import io, transform
import matplotlib.pyplot as plt
import tarfile
class NotMNIST(VisionDataset):
    """notMNIST dataset (letters rendered in many fonts).

    The archive is downloaded to ``<root>/NotMNIST/raw`` and extracted into
    ``<root>/notMNIST_large``; one sub-folder per class label. Images are
    loaded lazily from disk in ``__getitem__`` — ``_load_data`` only collects
    file paths and their labels (the folder names).
    """

    resource_url = 'http://yaroslavvb.com/upload/notMNIST/notMNIST_large.tar.gz'

    def __init__(
        self,
        root: str,
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super(NotMNIST, self).__init__(root, transform=transform,
                                       target_transform=target_transform)
        # Download in either of two situations:
        # - the caller passed download=True
        # - the data is not present on disk
        if download or not self._check_exists():
            self.download()
        self.data, self.targets = self._load_data()

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Load the image at *index* from disk and return (image, label)."""
        image_name = self.data[index]
        image = io.imread(image_name)
        label = self.targets[index]
        if self.transform:
            image = self.transform(image)
        # Fixed: original returned the misspelled name `iamge` (NameError)
        # and silently ignored target_transform.
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

    def _load_data(self) -> Tuple[List[str], List[str]]:
        """Walk the extracted image folder; return (file paths, labels)."""
        filepath = self.image_folder
        data: List[str] = []
        targets: List[str] = []
        # Each sub-directory name is the class label for the files inside it.
        for target in os.listdir(filepath):
            filenames = [os.path.abspath(os.path.join(filepath, target, x))
                         for x in os.listdir(os.path.join(filepath, target))]
            targets.extend([target] * len(filenames))
            data.extend(filenames)
        return data, targets

    @property
    def raw_folder(self) -> str:
        # Where the downloaded archive is stored.
        return os.path.join(self.root, self.__class__.__name__, 'raw')

    @property
    def image_folder(self) -> str:
        # Where the archive contents are extracted to.
        return os.path.join(self.root, 'notMNIST_large')

    def download(self) -> None:
        """Stream the archive to raw_folder with a progress bar, then extract it."""
        os.makedirs(self.raw_folder, exist_ok=True)
        fname = self.resource_url.split("/")[-1]
        chunk_size = 1024
        # Fixed: the original literal was missing its opening quote (SyntaxError).
        user_agent = ('Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_1) '
                      'AppleWebKit/537.36 (KHTML, like Gecko) '
                      'Chrome/39.0.2171.95 Safari/537.36')
        headers = {"User-Agent": user_agent}
        # HEAD request first, so tqdm knows the total size up front.
        filesize = int(requests.head(
            self.resource_url, headers=headers).headers["Content-Length"])
        archive_path = os.path.join(self.raw_folder, fname)
        # Fixed: the original multi-line `with` had no line continuation
        # (SyntaxError) and iterated with an undefined name `chunksize`.
        with requests.get(self.resource_url, stream=True, headers=headers) as r, \
                open(archive_path, "wb") as f, \
                tqdm(unit="B",            # unit string shown on the bar
                     unit_scale=True,     # auto-scale (kilo, mega, ...)
                     unit_divisor=1024,   # divisor used when unit_scale=True
                     total=filesize,      # total number of bytes expected
                     file=sys.stdout,     # show on stdout (default: stderr)
                     desc=fname           # prefix shown on the progress bar
                     ) as progress:
            for chunk in r.iter_content(chunk_size=chunk_size):
                # Download the file one chunk at a time...
                datasize = f.write(chunk)
                # ...and advance the progress bar by the bytes written.
                progress.update(datasize)
        # Fixed: the original never extracted the archive, so image_folder
        # did not exist and _load_data failed right after downloading.
        self._extract_file(archive_path, target_path=self.root)

    def _extract_file(self, fname: str, target_path: str) -> None:
        """Extract a .tar / .tar.gz archive into *target_path*."""
        if fname.endswith("tar.gz"):
            tag = "r:gz"
        elif fname.endswith("tar"):
            tag = "r:"
        else:
            # Fixed: `tag` was left undefined for any other extension.
            raise ValueError(f"Unsupported archive format: {fname}")
        # Context manager guarantees the tarfile is closed even on error.
        with tarfile.open(fname, tag) as tar:
            tar.extractall(path=target_path)

    def _check_exists(self) -> bool:
        # Presence of the raw folder is used as the "already downloaded" marker.
        return os.path.exists(self.raw_folder)
dataset = NotMNIST("data", download=True)

# Preview the first four samples in a 1x4 grid.
# Fixed: the original looped over range(8) but always broke out at i == 3,
# and called plt.show() from inside the loop body.
fig = plt.figure()
for i in range(4):
    image, label = dataset[i]
    ax = plt.subplot(1, 4, i + 1)
    plt.tight_layout()
    ax.set_title('Sample #{}'.format(i))
    ax.axis('off')
    plt.imshow(image)
plt.show()
import torch
from torchvision import transforms, datasets
# Standard ImageNet-style augmentation + normalization pipeline.
# Fixed: `RandomHorizontalFlib` was a typo (AttributeError) for
# `RandomHorizontalFlip`.
data_transform = transforms.Compose([
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # Per-channel mean/std used by torchvision's ImageNet-pretrained models.
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
# Fixed: the dataset was constructed twice on consecutive identical lines.
dataset = NotMNIST("data", download=False)
dataset_loader = torch.utils.data.DataLoader(dataset,
                                             batch_size=128, shuffle=True)
# Pull a single mini-batch to inspect its shape and labels.
train_features, train_labels = next(iter(dataset_loader))
# Fixed: `train.labels` was a typo (NameError) for `train_labels`, and the
# bare `train_features.shape` expression was duplicated; print so the values
# are visible outside a notebook.
print(train_features.shape)
print(train_labels)
# This post referenced the Connect Foundation's Naver AI Boost Camp course materials.