MRC Base_line 코드 이해(2)

SeongGyun Hong·2024년 10월 13일

MRC Base_line 코드 이해(1)

위 페이지에 이어 작성합니다.

2. train.py가 작동하는 과정 (do_train)

python train.py --output_dir ./models/train_dataset --do_train

위 코드를 CLI에서 명령하면 main()이 작동하게 됨.

2.1 main() 실행

parser config tokenizer model 까지 설정완료 되고
이후 training_args.do_train이거나 do_eval인 경우
run_mrc 실행

2.2 run_mrc 실행

check_no_error 함수가 실행

prepare_train_features 함수가 정의되고

training_args.do_train인 경우 (본 사안) train_dataset에 map으로

prepare_train_features 함수 적용시킴.

이후 data_collator를 DatacollatorWithPadding()으로 초기화시켜 저장하고

metric 선언해줌.

다음으로 compute_metrics를 정의해주고

QuestionAnsweringTrainer를 초기화하여 trainer에 저장함.

이때 초기화는 trainer_qa.py로 넘어가서 QuestionAnsweringTrainer클라스를 불러오는 것.

이후 trainig_args.do_train이 True이므로 (--do_train)
훈련 진행.

  • 이전에 훈련 진행했었고 훈련지점 남았으면 해당 지점을 찾고 없으면 디렉터리에서 찾은 후 둘 다 없을 시 새롭게 시작
  • 훈련 결과는 train_results.txt에 저장.
  • 상태 저장 : 마지막으로 트레이너의 상태를 json 파일로 저장함.
  • 모델 및 토크나이저 저장

3. train.py가 작동하는 과정 (do_eval)

python train.py --output_dir ./outputs/train_dataset --model_name_or_path ./models/train_dataset/ --do_eval

위 CLI 명령어는 일단 모델이 저장된 경로가 있어야함. (선 학습된 모델과 토크나이저)
그리고 do_eval로 실행함

3.1 main() 실행

parser config tokenizer model 까지 설정완료 되고
이후 training_args.do_train이거나 do_eval인 경우
run_mrc 실행

3.2 run_mrc 실행

check_no_error 함수가 실행됨.(do_train과 동일)

do_eval인 경우에 실제 prepare_train_features는 쓸일이 없지만, 일단 정의되고

if training_args.do_traindo_train이 아닌 do_eval이기에 넘어가고

그 아래에 있는 prepare_validation_features가 정의되고

if training_args.do_eval이 작동하여

eval_dataset에 대해서 prepare_validation_features가 적용됨.

이후 do_train때와 같이 data_collator가 선언되고

post_processing_function이 정의된 다음

metricload_metric("squad") 담기고

compute_metrics가 정의된 이후

trainerQuestionAnsweringTrainer가 초기화 됨

  • model
  • eval_dataset(전처리 후)
  • eval_examples(전처리 전)
  • tokenizer
  • data_collator
  • post_process_function
  • compute_metrics

위의 하이퍼 파라미터가 직접적으로 사용됨.

그리고 평가가 진행되는데

실행 순서는 이하와 같음.

  1. 평가 데이터셋을 사용하여 모델 평가
  2. 평가 지표 계산
  3. 평가 샘플 수 기록
  4. 평가 지표 로깅
  5. 평가 지표 저장

4. inference.py가 작동하는 과정

python inference.py --output_dir ./outputs/test_dataset/ --dataset_name ../data/test_dataset/ --model_name_or_path ./models/train_dataset/ --do_predict

  • --output_dir 설정
  • --dataset_name 설정
  • --model_name_or_path 설정
  • --do_predict 설정

4.1 main() 실행

먼저 parser를 이용하여 model_args data_args training_args 를 받고 training_args.do_trainTrue로 설정해준다.

datasets도 로드해주고

기본적인 config tokenizer model 또한 로드해준다.

이후

eval_retrieval는 기본값이 True인 바, Sparse embedding을 사용하는
run_sparse_retrieval()가 실행된다. (datasets에 저장)

training_args.do_predictTrue인 바,
run_mrc 함수 또한 실행된다.

4.2 run_sparse_retrieval() 실행

먼저 Retriever가 초기화 되면서 retriever는 주어진 토큰화 함수와 데이터 경로를 사용하게 된다.

이후 get_sparse_embedding()메서드를 호출하여 문서 corpus에 대한 sparse embedding을 생성하고

use_faiss 인자가 True인지 여부에 따라 FAISS 사용 여부를 결정한다.
만약 사용한다면 먼저 build_faiss()메서드로 FAISS 인덱스를 구축하고
retrieve_faiss()메서드를 사용하여 검색을 수행한다.

만약 FAISS를 사용하지 않는다면
**retrieve() 메서드를 사용하여 검색을 수행하고

검색결과는 DataFrame으로 받는다.
이때 do_predict do_eval 모드에 따라 적절한 특성(Features)를 정의한다.

최종적으로 DataFrame을 Hugging Face의 Dataset 객체로 변환하고
이를 DatasetDict에 포함시켜 반환한다.

결론적으로, 주어진 Query들에 대해서 관련 문서를 검색하고, 그 결과를 MRC 모델의 입력 형식에 맞게 구조화된 데이터셋으로 변환하는 역할을 해당 함수가 하는 것.

4.3 run_mrc() 실행

run_mrc 함수는 MRC를 직접 수행하는 파트이다. 안에
prepare_validation_features
post_processing_function
compute_metrics
함수가 정의되어 있다.

먼저 run_mrc()를 실행하고 나면 데이터 준비 단계로 들어간다.

  • 데이터셋의 컬럼 이름을 확인하고 질문, 컨텍스트, 답변 컬럼을 식별하고
  • 토크나이저의 패딩 방식을 설정

이후 전처리 함수 정의를 하게되는데

  • prepare_validation_features 함수가 정의되며 검증 데이터를 전처리하는 역할을 해준다.
  • 이 함수는 텍스트를 토큰화하고, 오버플로우 토큰과 오프셋 매핑을 처리한다.

이후 데이터셋 전처리를 실제 prepare_validation_features를 이용해서 변환시켜 주고

배치 처리를 위한 data_collator를 설정한다.

이후에는 후처리 함수를 정의하는데
post_processing_function을 정의하여 모델의 예측을 원본 컨텍스트와 매칭시킨다.
post_processing_function안에는 postprocess_qa_predictions 함수를 호출하여 실제 후처리를 수행하는 알고리즘이 들어있다.

이후에 SQuAD 평가 지표를 로드하고 compute_metrics 함수를 정의하고

QuestionAnsweringTrainer를 초기화하여 모델, 데이터셋, 토크나이저 등을 설정한다.

마지막으로 평가 또는 예측을 수행하는데 아래와 같이 갈래가 나뉜다.

  • do_predict 모드:
    • trainer.predict를 호출하여 예측을 수행
    • 결과는 predictions.json 파일로 저장
  • do_eval 모드:
    • trainer.evaluate를 호출하여 모델을 평가
    • 평가 지표를 계산하고 로그로 기록

다만, 본 inference에서는 do_predict모드이므로, do_eval은 생략됨.

profile
헤매는 만큼 자기 땅이다.

0개의 댓글