SLiC (Zhao et al., 2022, arXiv)

김수빈·2022년 11월 8일
9

논문 리뷰

목록 보기
13/14
post-thumbnail

📑 Paper

Zhao, Yao, et al. "Calibrating Sequence likelihood Improves Conditional Language Generation." arXiv preprint arXiv:2210.00045 (2022).

2022년 9월 Google Brain에서 발표한 논문으로, 아직 preprint지만 관심이 있어 읽어보았다

Contribution 정리

1. SLiC를 제안함으로써 기존의 MLE Framework의 uncalibrated problem 해결

Calibration stage를 추가했을 때, Text generation task에서 SOTA 또는 그에 준한 성능을 보였다.

2. 새로운 calibration similarity metric 제안

BERTScore 수식 활용하여, input context x를 고려한 candidate, target 간 유사도 계산할 수 있는 metric을 제안하였다.

3. Decoding heuristics에 대한 필요성 제거

대다수의 text generation 연구에서는 실험을 통해 heuristic하게 최적화해야하는
Beam size optimization, Length normalization, Repetition prevention을 수행했는데,

본 연구에서는 이를 배제했음에도 불구하고 높은 성능을 유지하였다.


1. 기존 Framework의 한계

Conditional Language Generation

이 task의 경우 input context x가 주어졌을 때, target sequence y의 확률을 학습하는 방식으로 모델링한다.

생성 가능한 모든 text sequence의 확률을 계산하는 것은 힘들기 때문에
auto regressive하게 순차적으로 token-level prediction을 수행하며,

이 과정에서 Maximum Likelihood Estimation (MLE)을 학습 objective로 활용한다.

Uncalibrated Sequence Likelihood

text generation task에서 가장 이상적인 세팅은 각 input context마다 여러개의 target sequence를 가지는 것이다.
output sequence들 간의 상대적인 frequency에 따라 모델의 확률을 교정할 수 있기 때문이다

💭 이해한 내용
document마다 여러 개의 summary를 가지는 경우는 아래와 같다.
Document Peter and Elizabeth took a taxi to attend the night party in the city. While at the party, Elizabeth collapsed and was rushed to the hospital.

Summary 1 Elizabeth was hospitalized after attending a party with Peter.
Summary 2 Elizabeth fainted at the night party.
Summary 3 Peter and Elizabeth attended the night party, and Elizabeth collapsed at the party.

summary 1~3을 비교했을 때, Elizabeth, party, hospitalized/fainted/collapsed 가 주요한 내용임을 알 수 있다.
이 token들의 경우, 확률이 높게 할당되도록 학습될 것이다.

하지만 대다수의 text generation task 데이터셋은 각 input context마다, 1개의 target sequence를 가진다.

문제점 정리

  • 시퀀스들을 비교하는 직접적인 supervision이 부족, 모델의 Generalization 능력에만 의존하게 된다.

  • MLE로 학습된 모델에 대해서, 모델 확률과 시퀀스의 품질 간 상관관계가 낮다.
    이는 앞선 논문들에서 이미 밝혀진 바다.

  • Teacher Forcing으로 인한 Exposure bias는 생성된 시퀀스의 품질을 더욱 악화시킬 수 있다.

기존 프레임워크로 학습된 모델의 경우, 시퀀스들간의 비교를 통해 적절한 liklihood를 취할 능력이 부족하다.

문제를 해결하기 위한 접근 방식

  • Reinforcement Learning with sequence-level rewards
    sequence level의 reward를 학습하는 강화 학습을 이용한 방식이다.
    이전 연구은 ROUGE score, Human judgement 등을 sequence-level reward로 사용했다

  • Two stage system (Generation and Reranking)
    candidate들을 생성하는 generation 모델과, 이 후보군들을 sequence-level score에 따라 재정렬하는 reranking 모델이 존재한다.
    generation model과 reranking model은 별도의 모델로, 추가적인 computing이 요구된다

  • Multi-task learning with sequence-level loss
    Sequence-level loss를 사용하는 방식

    여기서 multi-task는 서로 다른 역할을 수행하는 loss들을 결합하여 학습하는 방식이라고 이해했다.
    이전 SOTA 논문 (Liu et al., 2022)에서는 contrastive loss와 cross entropy를 결합한 objective를 multi-task loss라고 명시했다

본 연구에서는 3번째 접근 방식을 채택하였다.

2. SLiC (Sequence Likelihood Calibration)

본 연구에서 제안하는 framework는 아래 그림과 같다.
Calibration stage를 도입함으로써, MLE framework로 학습된 모델의 calibration 능력을 보완하고자 하였다.

  • Offline learning (Batch Learning)
    대다수의 강화학습 알고리즘과는 달리 batch learning이다.
  • Evaluation metric에 직접적으로 최적화하지 않는다
    evalution metric을 학습 objective에 활용하지 않았기 때문에, evalution metric에 직접적으로 최적화할 위험이 없다.
  • 새로운 모델이 아닌 fine-tuned model을 새로운 objective에 이어서 학습한다.

여기서 SLiC는 본 연구에서 새롭게 추가한 calibration stage를 일컫는다.

Calibration Stage Algorithm

calibration stage의 동작 방식은 다음과 같다.

  1. fine-tuned model로부터 m개의 candidate을 디코딩한다.

  2. calibrated model의 초기 파라미터를 fine-tuned model 파라미터로 초기화 한다. (그냥 이어서 학습하는것;)

  3. 모델을 새로운 objective L(θ)\mathcal{L}(\theta)에 학습한다.


Calibration Stage Objective

새로운 objective는 calibration loss와 regularization loss로 구성된다.

1) Calibration loss

Target sequence와의 유사도에 따라 sequence likelihood를 정렬하기 위함이다.

Similarity function
model의 latent space 내에서, candidate과 target과의 유사도를 계산한다.

BERTScore (Zhang et al, 2020)의 수식을 활용했으며, n=1,2,4,8n=1,2,4,8에 대해,
candidate과 target 간 코사인 유사도로부터 F1 score를 계산한 값을 유사도로 사용하였다.

BERTScore & Calibration Similarity
BERTScore는 text sequence를 BERT를 통해 임베딩한 representation 간 유사도를 계산하였지만,
본 연구에서는 input context x를 고려한 representation을 취하고자 Decoder hidden state를 활용했다.

BERTScore가 s(y,y^)s(\mathbf{y}, \mathbf{\hat{y}})이라면, Calibration similarity는 s(y,y^;x)s(\mathbf{y}, \mathbf{\hat{y}};\mathbf{x})이다.

4가지 유형의 Calibration loss
Calibration loss는 4가지 유형으로 구분되며,
4개 loss 모두 positive candidate과 negative candidate간의 차이를 극대화하는 방향으로 학습한다.

postive candidate ~ gold candidate
negative candidate ~ silver candidate

  • Rank loss
    단순히 positive candidate이 negative보다 높은 확률을 갖도록 학습한다.

  • Margin loss
    positive와 negative 간 similarity 차이만큼
    positive candidate이 negative candidate보다 높은 확률을 갖도록 학습한다.

  • List-wise rank loss
    Candidate들을 similarity에 따라 정렬했을 때,
    positive와 negative 간 rank의 차이만큼 positive가 더 높은 확률을 갖도록 학습한다.

  • Expected reward loss
    전체 candidate 중에서 similarity가 가장 높은 candidate이 높은 확률을 가지도록 학습한다.

2) Regularization loss

Fine-tuned model의 확률 분포에서 크게 벗어나지 않도록 규제하기 위함이다.

  • Cross-entropy
    기존에 많이 쓰이는 loss

  • KL Divergence
    여기서 KL Divergence는 fine-tuned model과 calibrated model의 확률 분포 차이를 낮추기 위해 사용되었다.



3. 실험

1) Ablation studies of Calibration

Calibration stage를 추가함으로써 성능 향상이 이루어졌는지,
Calibration loss과 regularization loss에 따른 성능 차이 등을 확인하기 위한 ablation study 수행

여기서 Δ\Delta는 fine-tuned model 대비 calibrated model의 RmR_m 개선율이다.
RmR_m은 각 데이터셋에서 ROUGE score의 기하평균 값을 산술평균 낸 값이다.

Rm=14ΣdR1R2RL3R_m = \frac{1}{4}\Sigma_d \sqrt[3]{R_1R_2R_L}
  • Similarity function에 따른 성능

ROUGE score로 유사도를 계산한 경우가 가장 높은 성능을 보였지만, 이는 evaluation metric에 직접적으로 최적화하였기 때문이다,
본 연구에서 제안한 방식은 evaluation metric에 직접 최적화하진 않았지만 그에 준수한 성능을 내는 것을 확인할 수 있다.

ROUGE : target과의 ROUGE score를 유사도로 계산한 경우
decoder repr : 본 연구에서 제안한 방식으로, decoder hidden state간 유사도를 계산한 경우
token emb : 각 시퀀스의 토큰 임베딩 간 유사도를 계산한 경우로, input context를 고려한 값은 아님

  • Calibration loss에 따른 성능
    Rank loss를 사용했을 때, 가장 높은 성능을 보인다.

  • Regularization loss에 따른 성능
    KL Divergence를 사용했을 때가 가장 높은 성능을 보이지만 cross entropy와 큰 차이를 보이진 않는다.
    regularization loss가 없는 경우는 확실히 더 낮은 성능을 보이는 것을 알 수 있다.

  • Candidate decoding method에 따른 성능
    m개의 candidate을 디코딩하는 과정에서 decoding method에 따른 성능을 관찰한 것이다.
    Beam search를 사용했을 때 가장 높은 성능을 보인다.

  • Checkpoint selection에 따른 성능
    Calibration stage에서의 checkpoint selection 기준에 따른 성능을 관찰한 것이다.
    Perplexity를 사용했을 때 가장 높은 성능을 보인다.

2) Benefits of Calibrated sequence likelihood

finetuned only model와 Calibrated model의 Candidate 수에 따른 generation 성능 비교

Fine-tuned only model은 candidate 수(beam size)가 많아질 때, 특정 지점을 넘어서면 성능이 하락하는 양상을 보인다.
반면, Calibrated model은 candidate 수가 많아짐에 따라 성능이 향상하는 경향을 보인다.


3) Scaling Properties of Calibrated models

모델 크기 별, 모델 연산량에 따른 generation 성능 비교
FLOPs (모델 연산량)이 증가할수록 성능이 향상되는 경향이 있다.
SAMSum과 CNN/DM 데이터셋에서는, 작은 calibrated model이 큰 calibrated model에 준하는 성능을 보였다.


4) Final Result

4가지 language generation task에 대해 이전 SOTA와 성능을 비교한 것이다.
대부분에서 SOTA를 달성했으며, 또는 그에 준하는 성능을 보였다.



4. 결론

본 논문에서는 기존 프레임워크의 Uncalibrated problem 문제를 해결하고자, Calibration stage를 추가한 프레임워크를 제안하였다.

대부분의 text generation task에서 성능이 개선됨을 확인하였으며, 또한 Decoding heuristics에 대한 필요성을 제거하였다.

profile
열심히 배우는 내가 되자

0개의 댓글