collate_fn에 대해

한승수·2024년 12월 20일

AI 코딩 팁

목록 보기
5/11
post-thumbnail

DataLoader를 작성할 때 collate_fn이라는 함수가 종종 쓰이곤 합니다. 오늘은 이 collate_fn이 무엇이고, 어떤 역할을 하는지 한 번 알아보겠습니다.

collate_fn 이란?

train_loader = DataLoader(dataset = dataset,
                          batch_size=2,
                          shuffle=False,
                          collate_fn = collate_fn
                          )

위와 같이 DataLoader를 정의할 때 collate_fn 라는 인자가 있습니다. collate_fn은 Data Loader가 데이터를 모델에 전달할 때 batch 단위의 어떤 변형이 필요할 때 거치는 함수입니다.

예를 들어 Object Detection을 수행하는 모델을 학습시키고자 한다고 합시다. Object Detection 데이터셋은 한 이미지에 여러개의 object를 담고 있습니다. 따라서 한 이미지에 target의 수가 여러개일수도, 한개일수도 있습니다. 이렇게 가변적인 target을 가지는 데이터셋에 대해서 data loader는 mini batch를 어떻게 로딩해야할까요?

정답은 collate_fn을 통해서 가변적인 target으로도 mini batch가 만들어지게 하는 것입니다.

DataLoaderDataset__getitem__ 메소드를 통해 batch size 만큼의 데이터를 불러옵니다. collate_fn을 거치지 않은 데이터은 dataset의 output 유형별로 묶여 하나의 batch가 됩니다.

예를 들면, Dataset의 __getitem__이 (image, target) 이라는 tuple을 반환한다고 합시다. 만일 배치사이즈가 2인 DataLoader에서는 이 [image,image], [target,target]과 같은 형태로 묶여 model에 input됩니다.

tensor 형태로 들어가는 이 데이터들은 같은 유형끼리 묶임과 동시에 stack 됩니다.

next_batch = next(iter(train_loader))
print(next_batch[0].shape)
>>> torch.Size([2,3,256,256])

만일 앞서 얘기했던 Object Detection의 경우 target의 수가 가변적이기 때문에 collate_fn 함수 없이 실행할 경우 target tensor가 서로 겹합되지 못해 오류가 발생합니다.

RuntimeError: stack expects each tensor to be equal size, but got [19, 4] at entry 0 and [14, 4] at entry 1

이럴 때 활용해 주는 것이 collate_fn함수입니다.
collate_fn을 활용해서 각 배치가 tensor 형태로 결합되지 않고, tuple 형태로 결합되도록 하면 어떨까요?

def collate_fn(batch):
    return tuple(zip(*batch))

위의 함수는 [(이미지,타겟), (이미지,타겟)] 형태의 batch를 input으로 받아 이를 tuple 형태로 변환해줍니다.

collate_fn함수를 DataLoader에 활용하면 tensor 형태로 로딩되던 데이터가 tuple로 로딩되어 두 가변적인 데이터가 결합되지 않고 튜플의 각 원소로 저장됩니다.

train_loader = DataLoader(dataset = dataset,
                          batch_size=2,
                          shuffle=False,
                          collate_fn = collate_fn
                          )
                          
next_batch = next(iter(train_loader))

위와 같이 collate_fn을 DataLoader에 활용해주면,

print(len(next_batch[0])
>>> 2

next_batch[0][0].shape
>>> torch.Size([3,256,256])

collate_fn 함수를 통해 한 배치에 있던 이미지가 하나의 텐서로 결합되는 것이 아닌, 각각의 이미지가 tuple의 원소가 되어 모델에 로드 됩니다.

결론

이렇듯 collate_fn 함수는 데이터셋 단위가 아닌 배치 단위에서 작업을 수행해야 할 때 유용합니다. 가장 흔히 많이 쓰이는 과정은 하나의 배치에서도 서로 가변적인 길이를 갖는 문장이나 Sequence에서 길이를 맞춰주기 위해 많이 활용됩니다.
물론 반복문이나 조건문을 활용해 배치속의 일부 이미지를 마스킹하거나, 배치 내부에서의 image augmentation이 이뤄지도록 유도하는 경우도 있습니다.

오늘은 간단하게 Dataset과 DataLoader가 데이터를 어떻게 처리하고 그 과정에서 collate_fn 함수가 어떻게 활용되는지 알아보았습니다. collate_fn함수를 앞으로 유용하게 쓸 수 있도록 여러 usecase들을 참고해야겠습니다.

profile
Grooovy._.Han

0개의 댓글