__init__(self): 생성자def __init__(self, data_dir, transform=None):
self.data_dir = data_dir
self.transform = transform # 데이터셋에 적용할 전처리 파이프라인
# 'no'와 'yes' 디렉터리의 모든 파일 경로를 리스트로 저장
self.image_paths = []
self.labels = []
for label, sub_dir in enumerate(["no", "yes"]): # 'yes'는 1, 'no'는 0으로 라벨링
sub_dir_path = os.path.join(data_dir, sub_dir)
# os.listdir: 하위 디렉토리의 모든 파일 이름 얻기
for file_name in os.listdir(sub_dir_path):
file_path = os.path.join(sub_dir_path, file_name)
# os.path.isfile: 각 파일 존재 여부 확인
if os.path.isfile(file_path):
self.image_paths.append(file_path) # 이미지 파일 경로 저장
self.labels.append(label) # 이미지 라벨 저장
__len__(): 데이터셋 크기 반환 def __len__(self):
return len(self.image_paths)
DataLoader에서 배치를 생성하거나 에포크를 정의할 수 있음__getitem__(index): 샘플 가져오기def __getitem__(self, index):
img_path = self.image_paths[index] # index에 해당하는 이미지 파일 경로 가져오기
label = self.labels[index] # index에 해당하는 이미지 라벨 가져오기
# 이미지 로드
image = Image.open(img_path).convert("RGB")
# 데이터 변환 적용
# transform(정규화, 텐서 변환 등을 수행) 따로 정의해주어야 함
if self.transform:
image = self.transform(image)
return image, label
import torch.utils.data
tr_dataset = BasicDataset(train_x, train_y)
train_loader = data.DataLoader(dataset=tr_dataset, batch_size=128, num_workers=8, shuffle=True)
tt_dataset = BasicDataset(test_x, test_y)
test_loader = data.DataLoader(dataset=tt_dataset, batch_size=128, num_workers=8, shuffle=False)
batch_sizeshuffleTrue로 설정, 검증/테스트 데이터에서는 False로 설정num_workersnum_workers, prefetch_factor, persistent_workersdrop_last[PyTorch] Dataset과 Dataloader 설명 및 custom dataset & dataloader 만들기