원하는 형태의 Batch Data 생성

이성범·2022년 4월 26일
0

Development

목록 보기
4/7
  • 우리는 보통 torch.utils.data에 있는 Dataset과 DataLoader를 이용하여 Batch 단위에 데이터를 생성하여 모델을 학습 시킴
  • 그런데 가끔 DataLoader로 생성되는 Batch 단위 데이터의 output을 다양한 형태로 변환해주고 싶을 때가 있음 (Batch 마다 다른 padding 생성 등)
  • 이럴때 사용하는 것이 collate_fn 임
  • 여기서 중요한 점이 collate_fn을 적절하게 만들어야지 Batch 단위 데이터 생성의 시간 복잡도를 줄일 수 있음 (무의미한 반복을 사용하면 시간 복잡도가 매우 높아져 모델 학습 속도가 매우 느려짐... 실제로 전 경험을 했습니다... 코드를 제대로 수정하니깐 6분 걸리던 학습이 1분 까지 단축되었습니다!)
class CustomDataset(Dataset):
    def __init__():

    def __len__(self):

    def __getitem__(self, idx):
        feature = self.feature[idx]
        target = self.target[idx]

        return {
            'feature' : feature, 
            'target' : target,
            }

def Custom_collate_fn(samples):
    feature = [sample['feature'] for sample in samples]
    target = [sample['target'] for sample in samples]
    '''
    feature와 target을 원하는 형태로 변형
    '''

    return torch.tensor(feature, dtype = torch.float32), torch.tensor(target, dtype = torch.long)

dataset = CustomDataset()
data_loader = DataLoader(dataset, collate_fn = Custom_collate_fn)
profile
Machine Learning Engineer at Konan Technology

0개의 댓글