Dataset은 샘플과 답을 저장한다.
아래 코드에서는 샘플=image, 답=label 로 되어있다.
!pip install pytorch
...
from torch.utils.data import Dataset
...
class <CustomName> (Dataset):
def __init__(self, ...):
...
def __len__(self):
...
def __getitem__(self, ...):
...
return image, label
사용자가 정의한 Dataset 클래스에는 반드시 3개의 함수를 구현해야 합니다.
init, len, getitme
주로 이미지 데이터가 있는 디렉토리 경로를 저장하거나 데이터 수치가 적힌 csv 파일을 읽을 때 여기서 정의합니다.
def __init__(self, img_path, file, transform=None, target_transform=None):
self.img_path = img_path
self.file_data = pd.read_csv(file)
self.transform = transform
self.target_transform = target_transform
transform이 무엇인지는 나중에 설명하겠습니다.
데이터셋의 샘플 개수를 반환합니다.
def __len__(self):
return len(self.file_data)
주어진 인덱스 값에 해당하는 샘플을 데이터셋에서 불러오고 반환합니다.
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
sample = {"image": image, "label": label}
return sample
주어진 예시에서 주어진 인덱스 값인 idx에 해당하는 이미지인 image 와 이미지의 라벨인 label 을 정의합니다. 정의한 image와 label을 정의된 transform에 맞게 변경하고 sample에 저장해 사전 형으로 반환합니다. 반환하는 형식은 튜플, 리스트 등 자유입니다.
정보가 많아서 도움이 많이 됐습니다.