PyTorch - Dataset

나라마야·2023년 7월 17일

PyTorch

목록 보기
1/5

공식 파이토치 사이트 (한국 사용자)

Dataset은 샘플을 저장한다.
아래 코드에서는 샘플=image, 답=label 로 되어있다.

!pip install pytorch

...

from torch.utils.data import Dataset

...

class <CustomName> (Dataset):
	def __init__(self, ...):
    	...
        
	def __len__(self):
    	...
	
    def __getitem__(self, ...):
    	...
        return image, label

사용자가 정의한 Dataset 클래스에는 반드시 3개의 함수를 구현해야 합니다.

init, len, getitme

init

주로 이미지 데이터가 있는 디렉토리 경로를 저장하거나 데이터 수치가 적힌 csv 파일을 읽을 때 여기서 정의합니다.

def __init__(self, img_path, file, transform=None, target_transform=None):
	self.img_path = img_path
    self.file_data = pd.read_csv(file)
    self.transform = transform
    self.target_transform = target_transform

transform이 무엇인지는 나중에 설명하겠습니다.

len

데이터셋의 샘플 개수를 반환합니다.

def __len__(self):
    return len(self.file_data)

getitem

주어진 인덱스 값에 해당하는 샘플을 데이터셋에서 불러오고 반환합니다.

def __getitem__(self, idx):
    img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
    image = read_image(img_path)
    label = self.img_labels.iloc[idx, 1]
    
    if self.transform:
        image = self.transform(image)
        
    if self.target_transform:
        label = self.target_transform(label)
    sample = {"image": image, "label": label}
    return sample

주어진 예시에서 주어진 인덱스 값인 idx에 해당하는 이미지인 image 와 이미지의 라벨인 label 을 정의합니다. 정의한 image와 label을 정의된 transform에 맞게 변경하고 sample에 저장해 사전 형으로 반환합니다. 반환하는 형식은 튜플, 리스트 등 자유입니다.

profile
언제나 나 자신에게 되물어 보기. So What?

1개의 댓글

comment-user-thumbnail
2023년 7월 18일

정보가 많아서 도움이 많이 됐습니다.

답글 달기