Custom Dataset

Sukhun-Net·2024년 6월 22일

커스텀 데이터셋(Custom Dataset)

torch.utils.data.Dataset을 상속받아 직접 커스텀 데이터셋(Custom Dataset)을 만드는 경우

torch.utils.data.Dataset

  • PyTorch에서 데이터셋을 제공하는 추상 클래스입니다.
  • 이를 상속받아 len 메서드와 getitem 메서드를 오버라이드하여 CustomDataset에서 사용

커스텀 데이터셋을 만들 때, 가장 기본적인 뼈대는 아래와 같다. 여기서 필요한 기본적인 define은 3개이다.

class CustomDataset(torch.utils.data.Dataset): 
  def __init__(self):
  # 데이터셋의 전처리를 해주는 부분

  def __len__(self):
  # 데이터셋의 길이. 즉, 총 샘플의 수를 적어주는 부분

  def __getitem__(self, idx): 
  # 데이터셋에서 특정 1개의 샘플을 가져오는 함수
  
  • len(dataset)을 했을 때 데이터셋의 크기를 리턴할 len
  • dataset[i]을 했을 때 i번째 샘플을 가져오도록 하는 인덱싱을 위한 get_item

커스텀 데이터셋 예시


import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader


# Dataset 상속
class CustomDataset(Dataset): 
  def __init__(self):  
    self.x_data = [[73, 80, 75],
                   [93, 88, 93],
                   [89, 91, 90],
                   [96, 98, 100],
                   [73, 66, 70]]
    self.y_data = [[152], [185], [180], [196], [142]]

  # 총 데이터의 개수를 리턴
  def __len__(self): 
    return len(self.x_data)

  # 인덱스를 입력받아 그에 맵핑되는 입출력 데이터를 파이토치의 Tensor 형태로 리턴
  def __getitem__(self, idx): 
    x = torch.FloatTensor(self.x_data[idx])
    y = torch.FloatTensor(self.y_data[idx])
    return x, y


추후에 코드에서 다음과 같이 활용한다. 

__len__ 
print(len(dataset))


__getitem__ 

x, y = dataset[1]  # 1번 인덱스의 데이터 접근
print(x)  # tensor([3.0, 4.0])
print(y)  # tensor([1.0])

__init__
이 경우는 생성자(Constructor) 로써 인스턴스 초기화 담당 

심화: Self의 의미


def __init__(self)

이때, self ?
* self는 파이썬 클래스 메서드에서 인스턴스 자신을 참조하는 변수 
* self를 사용해서 클래스의 속성과 메서드에 접근할 수 있다. 
* self의 사용 예시는 다음과 같다. 


class CustomDataset(Dataset):
    def __init__(self, data, targets):
        self.data = data        # self를 사용하여 인스턴스 속성 data를 초기화
        self.targets = targets  # self를 사용하여 인스턴스 속성 targets를 초기화

    def __len__(self):
        return len(self.data)   # self.data를 사용하여 데이터셋의 길이를 반환

    def __getitem__(self, idx):
        sample = self.data[idx]  # self.data를 사용하여 데이터의 특정 인덱스에 접근
        target = self.targets[idx]  # self.targets를 사용하여 타깃의 특정 인덱스에 접근
        return sample, target

여기서 init 메서드에서 self는 CustomDataset 클래스의 인스턴스를 가리킵니다. 예를 들어, 다음과 같이 인스턴스를 생성하면

dataset = CustomDataset(data, targets)

self는 CustomDataset의 인스턴스인 dataset을 가리키게 됩니다.

예를 들어보겠습니다.


class CustomDataset(Dataset):
    def __init__(self, data, targets):
        self.data = data        
        self.targets = targets  


data = [1, 2, 3, 4, 5]
targets = [0, 1, 0, 1, 0]

dataset = CustomDataset(data, targets)


dataset은 CustomDataset 클래스의 인스턴스가 됩니다. 

그리고 __init__ 메서드 내에서 self.data = data와 self.targets = targets를 통해 전달된 
data와 targets를 dataset(self)의 속성으로 초기화합니다. 

따라서 이후에 dataset.data와 dataset.targets는 각각 [1, 2, 3, 4, 5][0, 1, 0, 1, 0]을 
가리키게 됩니다.
profile
Data Scientist (Computer Vision, Multimodal)

0개의 댓글