[torch] collate_fn에 arguments 추가하기

‍한지영·2022년 3월 7일
4

pytorch Dataloader의 collate_fn 매개변수를 조작하면서 얻은 간단한 해결법에 대해 적는다.

pytorch는 torch.utils.data.Dataset과 torch.utils.data.DataLoader의 두 가지 도구를 제공한다.

Dataset은 input feature x와 label y를 input으로 받아 저장하며, DataLoader는 batch 기반으로 모델을 학습시키기 위해 dataset을 input으로 받아 batch size로 슬라이싱하는 역할을 한다. DataLoader에는 batchsize를 포함하여 여러가지 파라미터들이 있는데, 이 중 하나가 바로 collate_fn이다.

Collate_fn

dataset이 고정된 길이가 아닐 경우, batchsize를 2 이상으로 dataloader를 호출하면 dataloader에서 batch로 바로 못묶이고 에러가 난다. 따라서 텍스트데이터와 같이 variable length data를 다루고 batchsize를 2 이상으로 주고자 하는 경우에는 collate_fn 함수를 직접 작성해 넘겨주어야 한다.

#variable_dataset -> variable-size input인 경우
dataloader = torch.utils.data.DataLoader(variable_dataset, batch_size=2)
--------------------------------------------------------------------------------------------
#error code
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0.

Custom collate_fn()은 variable-length input을 batch로 잘 묶어서 dataloader로 넘겨주는 역할을 한다. Custom collate_fn()의 구현 예시와 custom collate_fn()을 거쳐 생성된 batched data를 dataloader로 넘겨주는 예제 코드는 아래와 같다.

# Custom collate_fn() example 1
def my_collate_fn(samples):
	inputs = [sample['input'] for sample in samples]
    labels = [sample['label'] for sample in samples]
    padded_inputs = torch.nn.utils.rnn.pad_sequence(inputs, batch_first=True) #padding
    return {'input': padded_inputs.contiguous(),
            'label': torch.stack(labels).contiguous()}
# Custom collate_fn() example 2
def my_collate_fn(batch):
  label_list, text_list, = [], []
  
  for (_text,_label) in batch:
    label_list.append(_label)
    processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
    text_list.append(processed_text)
  
  label_list = torch.tensor(label_list, dtype=torch.int64)
  text_list = pad_sequence(text_list, batch_first=True, padding_value=0)
  
  return text_list.to(device),label_list.to(device),
# feeding batched data to the dataloader
dataloader = torch.utils.data.DataLoader(variable_dataset,
                                         batch_size=2,
                                         collate_fn=my_collate_fn)

my_collate_fn() 함수를 보면 알 수 있듯이, custom collate_fn() 함수는 주로 variable-length batches를 패딩하는데에 쓰이곤 한다.

Supplying Arguments to Collate_fn(other than 'batch')

torch.utils.data.DataLoader 자체 기능으로
customed collate function의 대부분의 예시는 batch(samples)를 input으로 받아 x_feature_list와 y_label_list를 return한다. 다만 나의 경우는 batch 이외의 다른 argument(api call된 tokenizer)를 collate_fn에 넘겨주어야 하는 상황이었고, 구글링 후 간단한 해결책을 찾게 되었다.

# 예시코드
class MyCollator(object):
    def __init__(self. *params):
        self.params = params
    def __call__(self, batch):
        # do something with batch and self.params

collate_fn에 batch와 함께 넘기고자 하는 인자(나의 경우에는 tokenizer)를 init에 정의하고, 기존 collate function에서 batch를 input으로 받아 x list와 label list를 반환했던 기능을 call 함수 안에서 self.params와 함께 구현하면 된다.

# feeding to the dataloader
my_collator = MyCollator(param1, param2, ...)
data_loader = torch.utils.data.DataLoader(..., collate_fn=my_collator)

collate 과정을 거쳐 생성된 batched data를 dataloader에 feeding 하는 것은 위와 같이 Collator 인스턴스를 생성 후 data_loader에 'collate_fn=생성된 인스턴스'로 넘겨주면 된다.

Call method를 포함한 클래스를 만들어 DataLoader로 넘겨주면 해결되는 간단한 문제였다.






Reference: https://discuss.pytorch.org/t/supplying-arguments-to-collate-fn/25754/2

profile
NLP 전공 잡식성 문헌정보 석사생

1개의 댓글

comment-user-thumbnail
2023년 2월 15일

잘 읽었습니다! 아주 유용한 방법입니다! 질문이 있는데, class MyCollator(object) 에서 object는 어떤 역할을 하는 것일까요?

답글 달기