torch.utils.data의 Dataset을 상속하여 내가 학습시키고자 하는 데이터를 DataLoader에 제공하는 클래스를 만드는 방법을 정리한다.
Dataset를 상속하여 Custom Dataset 클래스를 만든다.
init, getitem, len을 필수적으로 재정의해야한다.

이미지는 /celeba-dataset/img_align_celeba/img_align_celeba/에 저장되어있다.

라벨은 /celeba-dataset/list_attr_celeba.csv의 'Smiling'칼럼의 값이다. 이 파일의 'image_id'칼럼엔 사진 파일의 파일명이 저장되어있다.
따라서 csv파일을 읽어와 'Smiling'칼럼을 라벨로, 이미지폴더의 경로+'image_id'로 사진의 경로로 사진을 읽어와 반환하면 되겠다.
from torch.utils.data import DataLoader, dataset
from PIL import Image
class FaceImage(Dataset):
def __init__(self, label_file, img_dir, transform=None):
df = pd.read_csv(label_file, delimiter=",")
self.labels = df[['image_id', 'Smiling']].copy()
self.img_dir = img_dir
self.transform = transform
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.labels.iloc[idx,0])
image = Image.open(img_path).convert("RGB")
label = self.labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
return image, label
def __len__(self):
return self.labels.shape[0]
이렇게 리스트의 기능을 할 수 있도록 getitem()과 len() 메소드를 정의하면 된다.
data = FaceImage(file, target_folder, trans)
train_dataset, val_dataset, test_dataset = random_split(data, [141819, 20261, 40519])
train_loader = DataLoader(dataset = train_dataset, batch_size = 32, shuffle = True, num_workers = 1)
val_loader = DataLoader(dataset = val_dataset, batch_size = 32, shuffle = True, num_workers = 1)
test_loader = DataLoader(dataset = test_dataset, batch_size = 32, shuffle = True, num_workers = 1)
이후 정의한 클래스를 이용하여 Dataset객체를 생성 후 DataLoader에 넣어서 사용하면 된다.
이렇게 Dataset과 DataLoader를 상속하여 클래스를 정의하면 원하는 방식으로 데이터를 학습시킬 수 있겠다.