Boostcamp week12 Inference, Retrieval, QAtrainer

Dae Hee Lee·2021년 10월 22일
0

BoostCamp_level2_Pstage_MRC

목록 보기
14/15

Baseline의 구조에 대해 요약하려한다.

Inference

1. Main function

main함수에서는 다음과 같은 순서로 코드가 실행된다.

  1. Argument Parser
  2. Logger, Dataset 정의
  3. Pretrained Reader model 호출
  4. 다음과 같이 실행

    a. eval_retrieval = True : run_sparse_retrieval()
    b. do_eval = True / do_predict = True : run_mrc()

2. run_sparse_retrieval

run_sparse_retrieval함수는 주어진 쿼리에 대해 topk개의 passage를 Sparse retrieval을 이용해 반환하는 함수이다.

  1. SparseRetrieval 클래스 선언
  2. 저장된 SparseRetrieval 호출
  3. Faiss 유무 확인, Retreive
    a. use_faiss = True : Faiss방식으로 Retrieve
    b. use_faiss = False : Exhaustive Retrieve
  4. Predict / Eval
    a. do_predict = True : dataset feature -> context, id, question
    b. do_eval = True : dataset feature -> answers, context, id, question
  5. Dataset 반환

3. run_mrc

run_mrc는 불러온 데이터에 대해 Reader가 작동하는 부분이다.

학습을 하는 코드가 아니기 때문에 Preprocess에서 validation에 해당하는 부분만 사용한다.

단, Post process에서는 다음과 같이 구분한다.
1. do_preict = True : id, predictions 반환
2. do_eval = True : id 별 prediction과 answer 평가
3. QA trainer에 각각 predict(), evaluate() 실행

Trainer_QA

QuestionAnsweringTrainer(Trainer)

HF의 Trainer를 상속받은 Custom trainer로, evaluate과 predict함수를 선언해줬다.

먼저, 입력 파라미터로 Trainer와 다른 점은 아래 두 파라미터이다.
eval_examples : feature로 쪼개지기 전 데이터, example
post_process_function

또한 evaluate과 predict는 모두 prediction_loop함수를 이용하고 있는데, Docs를 참고하자.

prediction loop 함수를 간단히 정리하면 다음과 같다.
1. prediction_loss_only 파라미터(Default는 TrainingArgument에서 False로 되어있다.)를 받는다.
2. prediction_loss_only가 True이면 예측값을 모으지 않는다.
3. prediction_loss_only가 False이면 예측값을 모아 Metric을 계산한다.
4. Evaluation 또는 prediction을 진행한다.

아직 Prediction Loss Only를 왜 하는 것인지는 정확하게 이해하지 못했다.

1. evaluate()

  1. dataset, dataloader, eval_examples, compute_metrics를 선언한다.(혹은 인자로 받는다)
  2. prediction_loop()를 실행한 결과를 output으로 저장한다.
  3. eval_examples, eval_dataset, output.prediction, argument를 인자로 받아 Post processing을 진행한다.
  4. post process의 결과값으로 metric을 계산하여 logging한다.
  5. metric을 반환한다. -> 반환된 값은 Inference.py에서 log, save 된다.

2. predict()

  1. trainer의 get_test_dataloader 함수를 이용해 test dataloader선언한다.(HF trainer에는 get_train_dataloader, get_eval_dataloader, get_test_dataloader로 데이터로더를 선언할 수 있다.)
  2. evaluation과 동일하게 prediction_loop()를 실행한 결과를 output으로 저장한다.
  3. 만약 post_processor나, compute_metric이 정의되지 않았다면 output을 그대로 반환한다.
  4. 그게 아니라면 Post_processing을 통해 예측값을 도출하고 이를 반환한다.
  5. 예측값은 postprocess_qa_predictions에서 json형태로 저장된다.

Retrieval

Sparse Retrieval Class

Retrieve를 진행하는 클래스로, Inference과정에서 호출하는 클래스이다.

0. init

  1. 아래 코드로 context를 정의한다(corpus)
self.contexts = list(
	dict.fromkeys([v["text"] for v in wiki.values()])
        )
  1. tfidf 객체를 생성한다.
  2. p_embedding,indexer를 None으로 initialize한다.

1. get_sparse_embedding

Passage Embedding을 만들고 TFIDF와 Embedding을 pickle로 저장한다. 만약 미리 저장된 파일이 있으면 저장된 pickle을 불러온다.

  1. 저장된 tfidf 파일이 경로에 있다면 불러온다.
  2. 경로에 없다면 fit_transform()한다.

2. build_faiss

속성으로 저장되어 있는 Passage Embedding을 Faiss indexer에 fitting 시켜놓는다. 이렇게 저장된 indexer는 get_relevant_doc에서 유사도를 계산하는데 사용된다.

  1. 저장되어있는 indexer파일을 사용한다.

    Faiss는 Build하는데 시간이 오래 걸리기 때문에 매번 새롭게 build하는 것은 비효율적이다. 그렇기 때문에 build된 index 파일을 저정하고 다음에 사용할 때 불러온다.

  2. 저장되어있지 않다면 IndexFlatL2, IVFScalerQuantizer를 이용해 train, add한다.

3. retrieve

str이나 Dataset으로 이루어진 Query를 받고, str 형태인 하나의 query만 받으면 get_relevant_doc을 통해 유사도를 구한다. Dataset 형태는 query를 포함한 Huggingface.Dataset을 받고,get_relevant_doc_bulk를 통해 유사도를 구한다.

  1. Query가 string일 때
    get_relevant_doc의 결과를 받아 스코어와 context를 반환한다.
  2. Query가 dataset일 때
    get_relevant_doc_bulk의 결과를 받아 question, id, context_id, context를 반환하고 만약 context와 question, answer까지 포함된 데이터를 사용한다면 Ground Truth인 context와 answer까지 추가로 반환한다.
  3. 이러한 정보들을 가지고 있는 데이터프레임을 최종 반환한다.

4. get_relevant_doc

위에서 봤듯 하나의 query에 대한 유사도를 계산하는 함수이다.

  1. query를 tfidf vectorizer를 이용해 transform한다.
  2. query 벡터와 passage 임베딩 벡터의 dot product를 통해 result를 출력한다.
  3. dim을 맞추기 위해 squeeze한 뒤 argsort를 한다.
  4. Topk개에 대한 score, index를 최종 반환한다.

5. get_relevant_doc_bulk

위에서 봤듯 여러개의 query에 대한 유사도를 계산해 Query개수만큼 score, index들을 반환한다.

6. retrieve_faiss

retrieve와 동일하게 구성되어있으나 get_relevant_doc_faiss, get_relevant_doc_bulk_faiss를 이용한다.

7. get_relevant_doc_faiss

  1. query를 임베딩한다
  2. 미리 선언해둔 indexer를 통해 search하고, score와 index를 반환한다.

8. get_relevant_doc_bulk_faiss

위와 동일하며, 여러개의 query에 대한 결과를 반환한다.

main : Retrieval.py 실행시켰을 때

Retrieval을 실행시키면 train/valid의 query를 모두 합쳐서 retrieve 결과를 보여준다.

profile
Today is the day

0개의 댓글