BERT 입력 구조에서 Batch 단위

Ann Jongmin·2025년 8월 12일

BERT

목록 보기
5/6

BERT 파인튜닝 시 모델에 입력하기 위해 DataLoader에서 batch 단위로 꺼내어 입력하게되는데, 데이터의 형태는 이전에 Dataset에서 데이터를 가공할때 최종적으로는 딕셔너리 형태로 만들어졌었다.

return {
            'text': text,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

이 딕셔너리 형태의 데이터를 DataLoader에서 batch단위로 (BATCH_SIZE = 16) 묶은 뒤에, 학습 단계에서 다시 batch 단위로 꺼내 모델에 입력하게되는데,

def train_epoch(model, data_loader, optimizer, device, scheduler):

    model.train()
    total_loss = 0

    for batch in data_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        total_loss += loss.item()
...

딕셔너리 데이터를 16 batch size롤 묶는 경우 실제 모델에 입력하는 데이터의 구조는 아래와 같다.

batch = {
    'input_ids': torch.tensor([
        [101, 2023, 2049, 102, 0, 0, ..., 0],   # 샘플 1의 토큰 ID (총 128개)
        [101, 2309, 1996, 102, 0, 0, ..., 0],   # 샘플 2의 토큰 ID (총 128개)
        [101, 2190, 2012, 102, 0, 0, ..., 0],   # 샘플 3의 토큰 ID (총 128개)
        [101, 2221, 2038, 102, 0, 0, ..., 0],   # 샘플 4의 토큰 ID (총 128개)
        [101, 2005, 2301, 102, 0, 0, ..., 0],   # 샘플 5의 토큰 ID (총 128개)
        [101, 2043, 2017, 102, 0, 0, ..., 0],   # 샘플 6의 토큰 ID (총 128개)
        [101, 2056, 2029, 102, 0, 0, ..., 0],   # 샘플 7의 토큰 ID (총 128개)
        [101, 2077, 2045, 102, 0, 0, ..., 0],   # 샘플 8의 토큰 ID (총 128개)
        [101, 2098, 2061, 102, 0, 0, ..., 0],   # 샘플 9의 토큰 ID (총 128개)
        [101, 2120, 2073, 102, 0, 0, ..., 0],   # 샘플 10의 토큰 ID (총 128개)
        [101, 2141, 2085, 102, 0, 0, ..., 0],   # 샘플 11의 토큰 ID (총 128개)
        [101, 2163, 2097, 102, 0, 0, ..., 0],   # 샘플 12의 토큰 ID (총 128개)
        [101, 2185, 2109, 102, 0, 0, ..., 0],   # 샘플 13의 토큰 ID (총 128개)
        [101, 2207, 2121, 102, 0, 0, ..., 0],   # 샘플 14의 토큰 ID (총 128개)
        [101, 2229, 2133, 102, 0, 0, ..., 0],   # 샘플 15의 토큰 ID (총 128개)
        [101, 2251, 2145, 102, 0, 0, ..., 0]    # 샘플 16의 토큰 ID (총 128개)
    ]),

    'attention_mask': torch.tensor([
        [1, 1, 1, 1, 0, 0, ..., 0],    # 샘플 1의 어텐션 마스크 (총 128개)
        [1, 1, 1, 1, 0, 0, ..., 0],    # 샘플 2의 어텐션 마스크 (총 128개)
        [1, 1, 1, 1, 0, 0, ..., 0],    # 샘플 3의 어텐션 마스크 (총 128개)
        [1, 1, 1, 1, 0, 0, ..., 0],    # 샘플 4의 어텐션 마스크 (총 128개)
        [1, 1, 1, 1, 0, 0, ..., 0],    # 샘플 5의 어텐션 마스크 (총 128개)
        [1, 1, 1, 1, 0, 0, ..., 0],    # 샘플 6의 어텐션 마스크 (총 128개)
        [1, 1, 1, 1, 0, 0, ..., 0],    # 샘플 7의 어텐션 마스크 (총 128개)
        [1, 1, 1, 1, 0, 0, ..., 0],    # 샘플 8의 어텐션 마스크 (총 128개)
        [1, 1, 1, 1, 0, 0, ..., 0],    # 샘플 9의 어텐션 마스크 (총 128개)
        [1, 1, 1, 1, 0, 0, ..., 0],    # 샘플 10의 어텐션 마스크 (총 128개)
        [1, 1, 1, 1, 0, 0, ..., 0],    # 샘플 11의 어텐션 마스크 (총 128개)
        [1, 1, 1, 1, 0, 0, ..., 0],    # 샘플 12의 어텐션 마스크 (총 128개)
        [1, 1, 1, 1, 0, 0, ..., 0],    # 샘플 13의 어텐션 마스크 (총 128개)
        [1, 1, 1, 1, 0, 0, ..., 0],    # 샘플 14의 어텐션 마스크 (총 128개)
        [1, 1, 1, 1, 0, 0, ..., 0],    # 샘플 15의 어텐션 마스크 (총 128개)
        [1, 1, 1, 1, 0, 0, ..., 0]     # 샘플 16의 어텐션 마스크 (총 128개)
    ]),

    'labels': torch.tensor([
        0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1
    ]),

    'text': [
        "첫 번째 텍스트 샘플입니다.",
        "두 번째 텍스트 샘플입니다.",
        "세 번째 텍스트 샘플입니다.",
        "네 번째 텍스트 샘플입니다.",
        "다섯 번째 텍스트 샘플입니다.",
        "여섯 번째 텍스트 샘플입니다.",
        "일곱 번째 텍스트 샘플입니다.",
        "여덟 번째 텍스트 샘플입니다.",
        "아홉 번째 텍스트 샘플입니다.",
        "열 번째 텍스트 샘플입니다.",
        "열한 번째 텍스트 샘플입니다.",
        "열두 번째 텍스트 샘플입니다.",
        "열세 번째 텍스트 샘플입니다.",
        "열네 번째 텍스트 샘플입니다.",
        "열다섯 번째 텍스트 샘플입니다.",
        "열여섯 번째 텍스트 샘플입니다."
    ]
}
profile
AI Study

0개의 댓글