Hello! Today we'll take a quick look at how to use the ImageFolder class, which comes in very handy when building vision models,
and then use it to build a custom class of our own :)
ImageFolder is a kind of Dataset class: given just the path to a root folder that contains one subfolder per class, it builds a Dataset object for you.
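To make that layout concrete, here is a stdlib-only sketch. It is not torchvision's implementation, and the folder and file names are made up. It builds a tiny `root/<class name>/<file>` tree in a temporary directory, then pairs each file with the index of its folder, which is roughly what ImageFolder does conceptually:

```python
import os
import tempfile

def list_samples(root):
    """Return (relative_path, class_index) pairs, treating each subfolder as one class."""
    # Class names are the subfolder names, indexed in sorted order
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    samples = []
    for name in classes:
        for fname in sorted(os.listdir(os.path.join(root, name))):
            samples.append((os.path.join(name, fname), class_to_idx[name]))
    return samples

# Hypothetical layout: root/<class name>/<image file>
with tempfile.TemporaryDirectory() as root:
    for cls, files in {"dog": ["c.jpg"], "cat": ["a.jpg", "b.jpg"]}.items():
        os.makedirs(os.path.join(root, cls))
        for fname in files:
            open(os.path.join(root, cls, fname), "wb").close()  # empty placeholder file
    samples = list_samples(root)

print([idx for _, idx in samples])  # [0, 0, 1]
```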
🐶 I'll use the Stanford Dog Dataset from my previous post again :)
# ubuntu linux
wget http://vision.stanford.edu/aditya86/ImageNetDogs/images.tar
tar -xvf images.tar
Running the commands above gives you a folder named Images full of adorable dog photos.
I checked, and the photos of golden retrievers (my favorite!) downloaded just fine!
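For a more systematic check than eyeballing a few photos, you can count the image files in each class folder. This is a minimal stdlib sketch. The suffix list and the demo folder name are assumptions. It runs on a fake folder so it stays self-contained, and on the real data you would call it on "Images/":

```python
import os
import tempfile
from collections import Counter

# Assumption: these suffixes cover the image files we care about
IMAGE_SUFFIXES = (".jpg", ".jpeg", ".png")

def count_images(root):
    """Count image files directly inside each class subfolder of root."""
    counts = Counter()
    for entry in os.scandir(root):
        if entry.is_dir():
            counts[entry.name] = sum(
                1 for fname in os.listdir(entry.path)
                if fname.lower().endswith(IMAGE_SUFFIXES)
            )
    return counts

# Demo on a made-up folder; on the real data: count_images("Images/")
with tempfile.TemporaryDirectory() as root:
    os.makedirs(os.path.join(root, "golden_retriever"))
    for fname in ["1.jpg", "2.JPG", "notes.txt"]:
        open(os.path.join(root, "golden_retriever", fname), "wb").close()
    counts = count_images(root)

print(dict(counts))  # {'golden_retriever': 2}
```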
Now let's put the ImageFolder class to use :)
import torch
import torchvision
from torchvision import transforms
size = 224
mean = (0.485, 0.456, 0.406)  # ImageNet per-channel mean
std = (0.229, 0.224, 0.225)   # ImageNet per-channel std
dog_transform = transforms.Compose([
    transforms.RandomResizedCrop((size, size), scale=(0.5, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),            # convert the PIL image to a tensor
    transforms.Normalize(mean, std),  # standardize each channel
])
dog_dataset = torchvision.datasets.ImageFolder('Images/', transform=dog_transform)
dog_dataset.class_to_idx
>>> {'n02085620-Chihuahua': 0,
 'n02085782-Japanese_spaniel': 1,
 'n02085936-Maltese_dog': 2,
 ...
Checking dataset.class_to_idx shows how the Dataset object maps each class to an index.
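When decoding model predictions, you usually need the reverse direction, from index to class. Here is a minimal sketch that uses the three entries from the output above. `breed_name` is a hypothetical helper, and it assumes every folder name follows the `<ID>-<breed>` pattern seen there:

```python
class_to_idx = {
    "n02085620-Chihuahua": 0,
    "n02085782-Japanese_spaniel": 1,
    "n02085936-Maltese_dog": 2,
}

# Invert the mapping so a predicted index can be turned back into a folder name
idx_to_class = {idx: name for name, idx in class_to_idx.items()}

def breed_name(folder_name):
    """Drop the leading ID ('n02085620-') and replace underscores with spaces."""
    return folder_name.split("-", 1)[1].replace("_", " ")

predicted_idx = 1  # e.g. the argmax of a model's output
print(breed_name(idx_to_class[predicted_idx]))  # Japanese spaniel
```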
data_loader = torch.utils.data.DataLoader(dog_dataset,
                                          batch_size=4,
                                          shuffle=True,
                                          num_workers=2)
images, labels = next(iter(data_loader))  # fetch a single batch
images.shape, labels.shape
>>> (torch.Size([4, 3, 224, 224]), torch.Size([4]))
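In those shapes, the first dimension is the batch size and the rest is one image's channels × height × width. The batch dimension equals batch_size only for full batches. With DataLoader's default drop_last=False, the last batch of an epoch holds whatever samples are left over. A small arithmetic sketch follows; the helper name is hypothetical, and it assumes default batching with no custom sampler:

```python
def batches_per_epoch(num_samples, batch_size, drop_last=False):
    """How many batches a DataLoader with default batching yields per epoch."""
    if drop_last:
        return num_samples // batch_size  # the incomplete last batch is discarded
    return (num_samples + batch_size - 1) // batch_size  # ceiling division

# 10 samples with batch_size=4 -> batches of 4, 4 and 2
print(batches_per_epoch(10, 4))                  # 3
print(batches_per_epoch(10, 4, drop_last=True))  # 2
```

This is also the value `len(data_loader)` reports in that setting.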
✨ And with the DataLoader class, we've also confirmed that batches come out as expected :)
Pretty easy, right?
"Wait, so why build a custom class at all? It already works fine!" If that's what you're thinking...
I thought the same, but this class turned out to have one blind spot!
That blind spot was "how class_to_idx is generated"!
The simple yet powerful ImageFolder class is great, but its class_to_idx assigns indices in "alphabetical order" (strictly, the sorted order of the class folder names).
For example, with the three labels apple / banana / cider, you get {"apple": 0, "banana": 1, "cider": 2}.
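Strictly speaking, the order comes from Python's sorted() over the folder names (see the find_classes source later in this post). That means character-code order, which is not the same as dictionary order. A quick stdlib check of both the example above and two cases where they differ:

```python
labels = ["cider", "apple", "banana"]
class_to_idx = {name: idx for idx, name in enumerate(sorted(labels))}
print(class_to_idx)  # {'apple': 0, 'banana': 1, 'cider': 2}

# sorted() compares strings code point by code point, which is not always
# the "alphabetical" order a person would expect:
print(sorted(["banana", "Cider", "apple"]))     # ['Cider', 'apple', 'banana'] (uppercase first)
print(sorted(["class10", "class2", "class1"]))  # ['class1', 'class10', 'class2'] (not numeric)
```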
Of course, alphabetical order is a fairly common rule, but at work I ran into cases where a model had to be trained with a class_to_idx that does not follow alphabetical label order.
So, for testing, I created a new class_to_idx object :)
import os
label_list = os.listdir('Images/')  # entries come back in arbitrary (filesystem-dependent) order
custom_class_to_idx = {label: idx for idx, label in enumerate(label_list)}
# a dict whose indices are NOT assigned in alphabetical order
custom_class_to_idx
>>> {'n02111500-Great_Pyrenees': 0,
'n02111889-Samoyed': 32,
'n02112018-Pomeranian': 1,
'n02112137-chow': 97,
...
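Relying on os.listdir means the order is whatever the filesystem returns, and any stray file in Images/ would be picked up as a "class". Below is a slightly more defensive sketch. `build_class_to_idx` is a hypothetical helper, not a torchvision function. It keeps only directories, takes the desired order as an explicit list, and fails loudly if the list and the folders disagree:

```python
import os
import tempfile

def build_class_to_idx(root, ordered_classes):
    """Map each class folder to its position in ordered_classes, after basic validation."""
    folders = {entry.name for entry in os.scandir(root) if entry.is_dir()}  # directories only
    missing = [name for name in ordered_classes if name not in folders]
    if missing:
        raise FileNotFoundError(f"No folder found for classes: {missing}")
    if len(set(ordered_classes)) != len(ordered_classes):
        raise ValueError("ordered_classes contains duplicates")
    # Folders not listed in ordered_classes are simply left out
    return {name: idx for idx, name in enumerate(ordered_classes)}

with tempfile.TemporaryDirectory() as root:
    for name in ["n02112018-Pomeranian", "n02111500-Great_Pyrenees"]:
        os.makedirs(os.path.join(root, name))
    open(os.path.join(root, "README.txt"), "w").close()  # stray file: ignored
    mapping = build_class_to_idx(root, ["n02112018-Pomeranian", "n02111500-Great_Pyrenees"])

print(mapping)  # {'n02112018-Pomeranian': 0, 'n02111500-Great_Pyrenees': 1}
```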
We'll build a Dataset class around this mapping :)
https://pytorch.org/vision/stable/_modules/torchvision/datasets/folder.html#ImageFolder
✨ The link above points to the ImageFolder source code.
Interestingly, ImageFolder adds very little code of its own; it simply inherits the logic of the DatasetFolder class as-is!
from typing import Any, Callable, Optional
from torchvision.datasets.folder import DatasetFolder, IMG_EXTENSIONS, default_loader
class ImageFolder(DatasetFolder):
    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        super().__init__(
            root,
            loader,
            IMG_EXTENSIONS if is_valid_file is None else None,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=is_valid_file,
        )
        self.imgs = self.samples
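Apart from forwarding its arguments and setting self.imgs, ImageFolder's only real contribution is passing IMG_EXTENSIONS as the extension filter, unless you supply is_valid_file. DatasetFolder then keeps a file only if its lower-cased name ends with one of those extensions. The following is a stdlib restatement of that check, mirroring the has_file_allowed_extension helper copied later in this post; the file names are made up:

```python
# Same tuple as torchvision 0.11's IMG_EXTENSIONS
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')

def has_file_allowed_extension(filename, extensions):
    # Lower-case first, so "photo.JPEG" passes; str.endswith accepts a tuple of suffixes
    return filename.lower().endswith(extensions)

files = ["n02085620_7.jpg", "photo.JPEG", "notes.txt", "images.tar"]
print([f for f in files if has_file_allowed_extension(f, IMG_EXTENSIONS)])
# ['n02085620_7.jpg', 'photo.JPEG']
```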
Then let's look at the DatasetFolder source code as well.
🤣 And it inherits from yet another class, VisionDataset! Do we have to dig into that one too?
No need. All we care about is changing the 'class_to_idx' attribute.
from typing import Any, Callable, Dict, List, Optional, Tuple
from torchvision.datasets.vision import VisionDataset
from torchvision.datasets.folder import find_classes
class DatasetFolder(VisionDataset):
    """
    (docstring omitted)
    """
    def __init__(
        self,
        root: str,
        loader: Callable[[str], Any],
        extensions: Optional[Tuple[str, ...]] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        classes, class_to_idx = self.find_classes(self.root)
        samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)
        self.loader = loader
        self.extensions = extensions
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]
    def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
        """Find the class folders in a dataset structured as follows::
            directory/
            ├── class_x
            │   ├── xxx.ext
            │   ├── xxy.ext
            │   └── ...
            │       └── xxz.ext
            └── class_y
                ├── 123.ext
                ├── nsdf3.ext
                └── ...
                └── asd932_.ext
        This method can be overridden to only consider
        a subset of classes, or to adapt to a different dataset directory structure.
        Args:
            directory(str): Root directory path, corresponding to ``self.root``
        Raises:
            FileNotFoundError: If ``dir`` has no class folders.
        Returns:
            (Tuple[List[str], Dict[str, int]]): List of all classes and dictionary mapping each class to an index.
        """
        return find_classes(directory)
    # (make_dataset, __getitem__ and __len__ omitted)
✨ Aha, self.class_to_idx is determined by the self.find_classes method.
So it looks like changing just the find_classes method will get us what we want!
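Why would changing that one method be enough? DatasetFolder.__init__ calls self.find_classes(...) and builds every sample from the mapping it returns, so a subclass's version is picked up automatically (the classic template-method pattern). The stdlib toy below mirrors only that control flow. ToyDatasetFolder and FixedOrderFolder are made-up names, not torchvision classes, and the hook takes a list of names instead of a directory so the example stays self-contained:

```python
class ToyDatasetFolder:
    """Toy stand-in for DatasetFolder: __init__ delegates class discovery to a hook method."""

    def __init__(self, class_names):
        # Same control flow as DatasetFolder.__init__: ask the hook, store the result
        self.classes, self.class_to_idx = self.find_classes(class_names)

    def find_classes(self, class_names):
        # Default behaviour mirrors torchvision: indices follow sorted() order
        classes = sorted(class_names)
        return classes, {name: idx for idx, name in enumerate(classes)}


class FixedOrderFolder(ToyDatasetFolder):
    """Overrides only the hook; __init__ is inherited unchanged."""

    def find_classes(self, class_names):
        classes = list(class_names)  # keep the caller's order
        return classes, {name: idx for idx, name in enumerate(classes)}


names = ["samoyed", "chihuahua", "pomeranian"]
print(ToyDatasetFolder(names).class_to_idx)  # {'chihuahua': 0, 'pomeranian': 1, 'samoyed': 2}
print(FixedOrderFolder(names).class_to_idx)  # {'samoyed': 0, 'chihuahua': 1, 'pomeranian': 2}
```

With the real library, the analogous move would be subclassing ImageFolder and overriding find_classes(self, directory), which the docstring above explicitly allows. ImageFolder's constructor has no parameter for a class list, though. So in that variant the desired order would have to reach the method some other way, for example through an attribute set before calling super().__init__(). Treat that as an idea to try, not a tested recipe.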
And here is the module-level find_classes function that the find_classes method calls :)
import os
from typing import Dict, List, Tuple
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Finds the class folders in a dataset.
    See :class:`DatasetFolder` for details.
    """
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx
OK! It returns the classes and class_to_idx that the Dataset class uses.
(classes is a list of the class names in index order.)
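Because classes is in index order, it already works as the reverse lookup (classes[i] is the name of class i). With a hand-made mapping, it is worth checking that the two stay in sync and that the indices are exactly 0..N-1, since losses such as torch.nn.CrossEntropyLoss expect targets in the range [0, C). Here is a minimal sketch; check_mapping is a hypothetical helper:

```python
def check_mapping(classes, class_to_idx):
    """Raise ValueError unless classes[i] maps to i and the indices are exactly 0..N-1."""
    if sorted(class_to_idx.values()) != list(range(len(classes))):
        raise ValueError("indices must be exactly 0..N-1")
    for idx, name in enumerate(classes):
        if class_to_idx.get(name) != idx:
            raise ValueError(f"{name!r} should map to {idx}")

classes = ["n02111500-Great_Pyrenees", "n02112018-Pomeranian", "n02111889-Samoyed"]
class_to_idx = {name: idx for idx, name in enumerate(classes)}
check_mapping(classes, class_to_idx)  # passes silently

# classes doubles as the index -> name lookup
print(classes[2])  # n02111889-Samoyed
```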
# from https://pytorch.org/vision/0.11/_modules/torchvision/datasets/folder.html
################################################################################
################################################################################
# copied from folder.py
import os
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
from typing import Union
from PIL import Image
from torchvision.datasets import VisionDataset, DatasetFolder
from torchvision.datasets.folder import find_classes  # used by make_dataset() when class_to_idx is None
def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:
    """Checks if a file is an allowed extension.
    Args:
        filename (string): path to a file
        extensions (tuple of strings): extensions to consider (lowercase)
    Returns:
        bool: True if the filename ends with one of given extensions
    """
    return filename.lower().endswith(extensions)
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
def pil_loader(path: str) -> Image.Image:
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')
# TODO: specify the return type
def accimage_loader(path: str) -> Any:
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)
def default_loader(path: str) -> Any:
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return accimage_loader(path)
    else:
        return pil_loader(path)
def make_dataset(
    directory: str,
    class_to_idx: Optional[Dict[str, int]] = None,
    extensions: Optional[Tuple[str, ...]] = None,
    is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]:
    """Generates a list of samples of a form (path_to_sample, class).
    See :class:`DatasetFolder` for details.
    Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function
    by default.
    """
    directory = os.path.expanduser(directory)
    if class_to_idx is None:
        _, class_to_idx = find_classes(directory)
    elif not class_to_idx:
        raise ValueError("'class_to_index' must have at least one entry to collect any samples.")
    both_none = extensions is None and is_valid_file is None
    both_something = extensions is not None and is_valid_file is not None
    if both_none or both_something:
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
    if extensions is not None:
        def is_valid_file(x: str) -> bool:
            return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))
    is_valid_file = cast(Callable[[str], bool], is_valid_file)
    instances = []
    available_classes = set()
    for target_class in sorted(class_to_idx.keys()):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            continue
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                if is_valid_file(fname):
                    path = os.path.join(root, fname)
                    item = path, class_index
                    instances.append(item)
                    if target_class not in available_classes:
                        available_classes.add(target_class)
    empty_classes = set(class_to_idx.keys()) - available_classes
    if empty_classes:
        msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
        if extensions is not None:
            msg += f"Supported extensions are: {', '.join(extensions)}"
        raise FileNotFoundError(msg)
    return instances
# copied from folder.py END
################################################################################
################################################################################
class CustomDatasetFolder(VisionDataset):
    def __init__(
        self,
        root: str,
        loader: Callable[[str], Any],
        class_list: List[str],
        extensions: Optional[Tuple[str, ...]] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        # the index order now comes from class_list instead of sorted folder names
        classes, class_to_idx = self.find_classes(class_list)
        samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)
        self.loader = loader
        self.extensions = extensions
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]
    @staticmethod
    def make_dataset(
        directory: str,
        class_to_idx: Dict[str, int],
        extensions: Optional[Tuple[str, ...]] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> List[Tuple[str, int]]:
        if class_to_idx is None:
            raise ValueError("The class_to_idx parameter cannot be None.")
        return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file)
    def find_classes(self, class_list: List[str]) -> Tuple[List[str], Dict[str, int]]:
        # index = position in the caller-supplied list
        return class_list, {label: idx for idx, label in enumerate(class_list)}
    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target
    def __len__(self) -> int:
        return len(self.samples)
class CustomImageFolder(CustomDatasetFolder):
    def __init__(
        self,
        root: str,
        class_list: List[str],
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        super().__init__(
            root,
            loader,
            class_list,
            IMG_EXTENSIONS if is_valid_file is None else None,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=is_valid_file,
        )
        self.imgs = self.samples
🤣 Phew... it got rather long while I was writing it ;<
custom_dog_dataset = CustomImageFolder('Images/', label_list, transform=dog_transform)
custom_dog_dataset.class_to_idx
>>> {'n02111500-Great_Pyrenees': 0,
'n02111889-Samoyed': 32,
'n02112018-Pomeranian': 1,
'n02112137-chow': 97,
...
With that, we have a Dataset class built on exactly the class_to_idx dict I wanted!
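One caveat: label_list came from os.listdir, whose order can differ between machines or filesystems. The same code could therefore produce a different class_to_idx somewhere else, for example on the machine that later loads the trained model. Saving the list once and reloading it keeps the mapping reproducible. Here is a minimal JSON sketch; the helper names and the file name are hypothetical:

```python
import json
import os
import tempfile

def save_class_list(class_list, path):
    """Write the class order to a JSON file so it can be reused later."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(list(class_list), f, ensure_ascii=False, indent=2)

def load_class_list(path):
    """Read back the class order written by save_class_list."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)

label_list = ["n02111500-Great_Pyrenees", "n02112018-Pomeranian", "n02111889-Samoyed"]
with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "class_order.json")  # hypothetical file name
    save_class_list(label_list, path)
    restored = load_class_list(path)

print(restored == label_list)  # True
```

On the real data, you would save label_list once and then pass the reloaded list as the class_list argument of CustomImageFolder.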
Honestly, I'm a little disappointed with today's post.
🤦‍♂️ I figured simple inheritance and overriding would give me a custom class with little effort...! The concept was simple, but the code ended up quite long.
I may have missed something, so I think it needs a second pass: dig further into the code and see whether there's a simpler, cleaner way to do this.
Well... I still got the result I wanted, so it counts, right?! (Ha. Ha. Ha.)
Thanks for reading, and see you next time!