Boostcamp week12 day1 Pre/Post-processing functions

Dae Hee Lee·2021년 10월 18일
0

BoostCamp_level2_Pstage_MRC

목록 보기
10/15

Pre-processing functions, Post-processing functions

Pre-processing for train/valid의 경우 오피스아워와 피어세션을 통해 잘 살펴보았다.

그렇다면 Post-processing과정을 살펴보도록 하자.

먼저, 이전에 Post-processing의 과정에서 필요한 요소들을 여기에서 정리했었다.

  1. 불가능한 답 제거하기(다음과 같은 경우 candidate list에서 제거)
  • End position 이 start position보다 앞에 있는 경우 (e.g. start = 90, end =80)
  • 예측한 위치가 context를 벗어난 경우 (e.g. question 위치쪽에 답이 나온 경우)
  • 미리 설정한 max_answer_length 보다 길이가 더 긴 경우
  1. 최적의 답안 찾기
  • Start/end position prediction에서 score (logits)가 가장 높은 N개를 각각 찾는다. 총 2개이다.
  • 불가능한 start/end 조합을 제거한다.
  • 가능한 조합들을 score의 합이 큰 순서대로 정렬한다.
  • Score가 가장 큰 조합을 최종 예측으로 선정한다.
  • Top-k 가 필요한 경우 차례대로 내보낸다.

그렇다면, Huggingface QA code를 살펴보자.
여기에 코드가 올려져 있다.

중간에 feature_null_score를 업데이트하는 과정이 있다. 아래 코드는 For 문 안에 선언된 것으로, for 문 밖에서 min_null_prediction은 None으로 먼저 선언되었다.

feature_null_score = start_logits[0] + end_logits[0]
            if (
                min_null_prediction is None
                or min_null_prediction["score"] > feature_null_score
            ):
                min_null_prediction = {
                    "offsets": (0, 0),
                    "score": feature_null_score,
                    "start_logit": start_logits[0],
                    "end_logit": end_logits[0],
                }

start_logits[0], end_logits[0]를 활용하는 점이 신기하다. 해당하는 토큰은 CLS토큰에 해당하는 값이며, start와 end의 CLS 토큰값을 더해서 feature_null_score를 만들고 이를 계속 업데이트하는 방식이다. CLS 토큰을 정답으로 예측하기란 쉽지 않을 것이라는 방식 때문이라 생각한다.

또한 코드 중간에 아래와 같은 이중 for 문이 나오게 된다. start_indexes와 end_indexes는 각각 n_best_size만큼 원소를 가지고 있고, 따라서 아래의 결과로 인해 n_best_size의 제곱만큼의 score들을 prelim_prediction에 저장한다.(start score + end score)

그 다음 prelim_prediction을 다시 Sort, n_best_size 만큼 추출한다.

for start_index in start_indexes:
	for end_index in end_indexes:
    		...
prelim_predictions.append({
          "offsets": (
              offset_mapping[start_index][0],
              offset_mapping[end_index][1],
          ),
          "score": start_logits[start_index] + end_logits[end_index],
          "start_logit": start_logits[start_index],
          "end_logit": end_logits[end_index],
          })
          
predictions = sorted(prelim_predictions, 
		     key=lambda x: x["score"], 
             	     reverse=True)[:n_best_size]

이 다음은 해당하는 prelim prediction에서 offset 사이의 index의 text를 최종적으로 prediction으로 예측하게 되며, 파일로 저장하는 단계를 거친다.

하지만 코드에서 눈여겨 볼 점은 이렇게 단순한 경우가 아니라, version_2_with_negative값으로 나타내는, 정답이 없는 데이터셋이 포함되어있는 경우를 생각하는 것이 중요하다.

이 설명을 위해 위에서 설명했던 코드가 사용된다.

feature_null_score = start_logits[0] + end_logits[0]
            if (
                min_null_prediction is None
                or min_null_prediction["score"] > feature_null_score
            ):
                min_null_prediction = {
                    "offsets": (0, 0),
                    "score": feature_null_score,
                    "start_logit": start_logits[0],
                    "end_logit": end_logits[0],
                }

먼저 min_null_prediction은 각 example마다 Initiate되는 값이다. 그리고 for 문을 따라서 example 안의 feature들에 대해서 이 값을 update해주는데, 더 작은 feature_null_score(CLS 토큰 logit값들의 합)로 offset은 CLS토큰을 가리키며 업데이트하는 것이 핵심이다.

모든 feature에 대해 for문이 끝난 다음, 현재 example에서 계속해서 min_null_prediction값을 활용한다. 정답이 없는 데이터셋이기 때문에 prelim_prediction에 min_null_prediction을 append하고 null_score라는 변수에 min_null_prediction의 score값을 부여한다.

if version_2_with_negative:
      prelim_predictions.append(min_null_prediction)
      null_score = min_null_prediction["score"]

마지막으로 아래 코드를 통해 각 example에 대한 predict정보를 담은 all_predictions를 반환하게 된다. 중요한 것은 null_score보다 작은 score를 가지지 않아야만 text를 반환하고, null_score보다 작은 score를 가진다면 빈 텍스트를 반환하게 된다.

i = 0
while predictions[i]["text"] == "":
    i += 1
best_non_null_pred = predictions[i]

# threshold를 사용해서 null prediction을 비교합니다.
# null_score - best_non_null_pred['score']
score_diff = (
    null_score
    - best_non_null_pred["start_logit"]
    - best_non_null_pred["end_logit"]
)
scores_diff_json[example["id"]] = float(score_diff)  # JSON-serializable 가능
if score_diff > null_score_diff_threshold:
    all_predictions[example["id"]] = ""
else:
    all_predictions[example["id"]] = best_non_null_pred["text"]

위와 같은 과정을 거치는 이유를 나름대로 이해한 내용은 다음과 같다. 데이터셋에 정답이 없는 경우도 포함이라는 말은, QA Task에서 Question이 주어졌을 때 '주어진 Context에서 답을 찾을 수 없다'라는 말과 같다. 따라서 post-processing을 진행하면서 threshold를 정하고 특정 기준을 만족하는 값이 나왔을 때에 text를 반환하고, 그렇지 않으면 답이 없다고 판단할 수 있도록 후처리를 진행하는 것이다. 그리고 여기서 정한 기준이 되는 값은 threshold를 하이퍼파라미터로 줄 수도 있지만, 기본적으로 해당하는 example의 feature들 중 더 적은 CLS token들의 logit값의 합을 가지는 정보를 바탕으로 한다.

profile
Today is the day

0개의 댓글