샘플된 KorQuAD 데이터에 대해서 BERT를 fine-tuning 시키는 코드를 실습해보며, max_train_samples
를 조절하면서 원하는 개수만큼 학습 데이터를 선택할 수 있다.
Point 1. Huggingface의 Question Answering 모델 확인
Point 2. 학습된 모델을 불러와 모델 평가
Extra Mission. 제공된 baseline 코드에는 대회에서 사용하는 KLUE MRC 데이터에 대해 BERT를 학습하는 부분이 있다. 추가적인 미션으로 대회에서 사용하는 KLUE MRC 데이터에 대해서 BERT를 fine-tuning 시킨 후 성능을 비교해보자. (EM 46.6 % / F1 59.9%에 가깝게 나오면 제대로 학습된 것이다.)
Generation based MRC는 지문에 정답이 존재하지 않아도 가능하며, 정답을 지문에서 찾아내는 Extraction based와 달리 정답을 Generate 하는 방식이다. 대표적인 Seq2Seq모델인 T5를 이용하여 Generation based MRC with T5를 진행해보자.
# Requirements
!pip install tqdm==4.48.0 -q
!pip install datasets==1.4.1 -q
!pip install transformers==4.5.0 -q
!pip install sentencepiece==0.1.95 -q
!pip install nltk -q
import nltk
nltk.download('punkt')
from datasets import load_dataset
from datasets import load_metric
datasets = load_dataset("squad_kor_v1")
metric = load_metric("squad")
AutoModelFroSeq2SeqLM
을 사용한다.# T5는 seq2seq 모델이므로 model을 불러올 때 AutoModelForSeq2SeqLM을 사용해야 함
config = AutoConfig.from_pretrained(
model_name,
cache_dir=None,
)
tokenizer = AutoTokenizer.from_pretrained(
model_name,
cache_dir=None,
use_fast=True,
)
model = AutoModelForSeq2SeqLM.from_pretrained(
model_name,
config=config,
cache_dir=None,
)
</s>
를 사용한다. 따라서 input값에 end token을 넣어주는 전처리를 수행한다.inputs = [f"question: {q} context: {c} </s>" for q, c in zip(examples["question"], examples["context"])]
targets = [f'{a["text"][0]} </s>' for a in examples['answers']]
그 다음, Tokenizer를 이용해 문장을 tokenizing한다. Max length나 padding, truncation을 결정해서 인자로 넣어준다. 결과물로 example_id
, label
값을 가진 dictionary형태의 model_inputs를 반환한다.
마지막으로 위에서 정의한 preprocess함수를 데이터셋에 적용한다.(train, validation both)
data collator
huggingface의 datacollator
와 Trainer
, trainingArguments
를 불러온다.
data collator의 경우 토크나이저 객체와 모델, label_pad_token_id
등을 인자로 준다.
Postprecessing, Compute_metrics
입력값으로서는 preds
,labels
를 받는다. 즉, 내가 예측한 Answer 텍스트와 실제 정답 텍스트를 입력으로 받게 된다.
(preds,labels)
를 Eval_preds
로 받아 batch_decode함수를 이용해 tokenized된 입력값들을 decoding한다. 이렇게 나온 decoded preds와 decoded labels를 이용해 위에서 불러왔던 metric의 compute를 통해 result를 계산하고 반환한다.