- 우리는 보통 torch.utils.data에 있는 Dataset과 DataLoader를 이용하여 Batch 단위에 데이터를 생성하여 모델을 학습 시킴
- 그런데 가끔 DataLoader로 생성되는 Batch 단위 데이터의 output을 다양한 형태로 변환해주고 싶을 때가 있음 (Batch 마다 다른 padding 생성 등)
- 이럴때 사용하는 것이 collate_fn 임
- 여기서 중요한 점이 collate_fn을 적절하게 만들어야지 Batch 단위 데이터 생성의 시간 복잡도를 줄일 수 있음 (무의미한 반복을 사용하면 시간 복잡도가 매우 높아져 모델 학습 속도가 매우 느려짐... 실제로 전 경험을 했습니다... 코드를 제대로 수정하니깐 6분 걸리던 학습이 1분 까지 단축되었습니다!)
class CustomDataset(Dataset):
def __init__():
def __len__(self):
def __getitem__(self, idx):
feature = self.feature[idx]
target = self.target[idx]
return {
'feature' : feature,
'target' : target,
}
def Custom_collate_fn(samples):
feature = [sample['feature'] for sample in samples]
target = [sample['target'] for sample in samples]
'''
feature와 target을 원하는 형태로 변형
'''
return torch.tensor(feature, dtype = torch.float32), torch.tensor(target, dtype = torch.long)
dataset = CustomDataset()
data_loader = DataLoader(dataset, collate_fn = Custom_collate_fn)