[NLP23-1_2] BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension(2020, ACL)

fla1512·2023년 2월 20일
1

NLP Study

목록 보기
15/23

1 Introduction

이전의 연구 방향과 한계점

  • self-supervised 방법들은 NLP의 많은 분야에서 좋은 성과를 거둠
    • 예) Word2Vec, ELMo, BERT, SpanBERT, XLNet, Roberta
  • 이 중 가장 성공적인 성과는, MLM(masked language models)의 변형
    • = denoising autoencoders, 문장 내 존재하는 단어의 집합이 가려진 텍스트를 다시 재구축하는 방식
  • 최신 연구는 MLM의 다양한 변형으로 발전
    • 예) masked tokens이 예측되는 순서(xlnet), masked tokens의 분포를 개선(spanbert), masked token을 교체할 context를 개선(UniLM)
  • 하지만, 이런 방법들은 특정 end task(예. span prediction, generation)에 초점을 두어서 -> 활용성에 한계가 있음

논문의 해결책: BART

  • 특징
    • Bidirectional이고 Auto-Regressive한 Transformers를 결합한 모델
    • denoising autoencoder임
    • sequence-to-sequence model로 만들어져서 여러 분야의 end task에 적용 가능
    • pretraining의 두 과정
      • 1) text는 arbitrary noising function으로 오염됨
      • 2) sequence-to-sequence model이 본 문장을 재구축하고자 학습됨
    • standard한 Tranformer 기반의 neural machine translation architecture를 사용(Fig1)
      • 그로 인해 BERT (bidirectional encoder), GPT (left-to-right decoder), 그리고 다른 방식들을 일반화하는 것처럼 보임

Method

  • 가장 큰 이점은 noising flexibility
    • = 임의의 변형이 원본 문장에 적용될 수 있다는 뜻(길이를 바꾸는 것을 포함해서)

      해당 부분부터는 논문 뒤에서 상세하게 설명해준다 !

  • 본 연구에서는 몇 noising 방법들을 평가해서 최고의 성능을 찾고자 함
    • 1) 본 문장의 순서를 임의로 섞기
    • 2) in-filling scheme을 사용하기
  • 이런 방법들은 1) original word masking과 BERT에서의 2) next sentence predicition를 일반화해서 -> 모델이 전체적인 문장 길이에 대해서 더 합리적이게 된다 (=transformation을 하기 위한 과정으로, input에 더 긴 문장을 넣는 것을 가능하게 한다)

Experiment & Result

  • text generation에서 fine tuned되었을 때+ comprehension task 특히 효율적
  • comprehension task는 RoBERTa와 GLUE, SQuAD에서 좋은 성과를 거두었고, 'abstractive dialogue, question answering, summarization tasks'에서 SOTA 달성
    • 예시로, XSum에서 이전 연구보다 성능을 6 ROUGE 올림
  • fine-tuning에 대한 새로운 시야 제공
    • 논문에서는 BART가 몇 additional transformer layers위에 stack되는 새로운 machine translation 방법을 제시
    • 해당 레이어들은 foreign language를 noised English로 essentialy하게 translate하고자 train됨
      • BART의 propagation을 통해서, BART를 pre-trained target-side language model로서 쓰면서
      • 해당 approach는 strong back-translation MT baseline에서 성능 향상을 보임(WMT Romanian-English benchmark에서 1.1 BLEU)

ablation analysis

  • 위에서 언급한 효과들을 잘 이해하기 위해서 ablation analysis 진행
    • 이 때 다른 최근에 제안된 training objectives들도 활용함
  • 본 연구는 training objectives에 성능에 있어서 중요하다고 알려진, data, optimization parameters를 포함한 factor들을 control하는 것을 가능하게 함
  • 우리는 BART가 우리가 생각했던 task에서 가장 좋은 성과를 지속적으로 보여준다는 것을 발견함

7 Related Work

2017 Jun : Transformer
2018 Feb : ELMo
2018 Jun : OpenAI GPT
2018 Oct : BERT
2019 Feb : OpenAI GPT-2
2019 May : UniLM
2019 Jun : MASS
2019 Jun : XL-Net
2019 Oct : BART
2019 Oct : T5
2020 Jun : OpenAI GPT-3

    1. GPT
    • left-ward context만 models함 -> 몇 task에 문제가 됨
    1. ELMo
    • left-only와 right-only representation을 concatenate함
    • 그러나, 해당 특징들 사이의 interactions를 pre-train하지는 않음
  • 매우 큰 언어모델이 unsupervised multitask model로서 작용할 수 있음을 입증한 연구도 있었음

    1. BERT
    • masked language modelling을 소개함
      • pre-training이 left와 right context words 사이의 interactions를 학습하는 것을 가능하게 함
    • predictions가 auto-regressively하게 이루어지지 않아, generation task에서 효율성이 낮음
  • 3-1. 최근 연구

    • 1) training for longer, 2) tying parameters across layers, 3) masking spans instead of words 같은 방법들을 통해서 강한 performance를 얻을 수 있음이 입증
    1. UniLM
    • BERT를 masks의 ensemble로 fine-tuning함
      • 그들 중, 몇 개는 leftward context만을 허용함
      • 이로 인해 BART처럼 generative와 discriminative task에서 모두 쓰이는 것이 가능해짐

        + discriminative와 generative의 차이점

      • BART와의 차이점은, UniLM은 predictions를 conditionally independent하게 하고, BART는 autoregressive하게 한다는 것임
        • BART는 decoder가 uncorrupted한 context에서 훈련되기 때문에, pre-training과 generation task 간의 mismatch를 줄일 수 있음

          autoregressive의 경우 상황을 dependent하게 본다는 내용이 함축되어있다. 즉, BART의 경우 앞에서 예측한 것을 뒤에 예측할 때 쓰는 등의 방식을 활용하는데 UniLM의 경우 이러지 않는다는 뜻.

    1. MASS
    • BART와 가장 유사한 모델
    • input sentence(where a contiguous span of tokens is masked)는 missing tokens를 포함하는 sequence로 mapped됨
    • discriminative tasks에서는 BART보다 상대적으로 비효율적임
      • disjoint sets of tokens가 encoder와 decoder에 fed되어서
    1. XL-Net
    • masked tokens를 permuted order로 auto-regressively하게 예측하는 BERT의 확장 모델
      • predictions가 left, right context 모두에서 condition하는 것을 가능하게 함
        • 이와 달리 BART decoder는 pre-training 때 left-to-right하게 작동하면서 generation 동안 setting을 matching함

정리하면,
machine translation을 향상하기 위해서 pre-trained representations를 사용하는 몇 연구들이 있었다

  • 가장 큰 향상은, pre-training을 source와 target languages 모두에 하는 것
    • 하지만 이는 관심 있는 모든 languages를 pretraining하는 것을 필요로 한다는 한계가 있음
  • 다른 연구는 encoder가 pre-trained representations를 사용했을 때
    • 하지만 decoder에서의 gain이 부족하다는 한계

그래서 본 논문에서는 BART가 machine translation decoders를 향상하는데 쓰일 수 있음을 입증했다

2 Model

  • BART는 corrupted document를 원래 있었던 original document로 map하는 denoising autoencoder
  • sequence-to-sequence model로 시행
    • 1) corrupted text가 있는 bidirectional encoder와
    • 2) left-to-right autoregressive인 decoder

2.1 Architecture

  • standard한 sequence-to-sequence Transformer architecture를 따름
    • 이 떄, GPT의 다음 과정은 따르지 않음!
      • ReLU activation functions를 GeLUs로 바꾼 것
      • parameters를 N(0, 0.002)로 초기화 한 것
  • Base model
    • encoder와 decoder에서 6 layers(large model은 12개)
  • 아키텍처는 BERT와 유사한데
    • 차이점은,
      • 1) decoder의 각 레이어는 encoder의 final hidden layer에 대해서 추가적으로 cross-attention을 수행
      • 2) BERT는 word-prediction 이전에 추가적인 feed-forward network를 사용(BART는 해당 과정 생략)
    • 결론적으로 BART는 같은 사이즈일 때 BERT보다 10% 더 많은 parameters를 포함

2.2 Pre-training BART

  • 1) documents를 corrupting하고 2) decoder의 output과 original document 사이의 reconstruction loss를 optimizing하는 과정에서 훈련
  • 특정 noising 방식에만 적용 가능했던 이전의 denosing autoencoders와 달리, any type of document corruption이 가능
  • 그래서 논문에서는 이전 제시 방법과 새로운 방법들에 대해서 몇 실험을 진행함

이미지 출처

1 Token Masking

  • BERT와 동일
  • random tokens가 [MASK] elements와 함께 sample되고 replace됨

2 Token Deletion

  • input에서의 random tokens이 삭제됨
  • 이때 token masking과 달리, model은 어떤 위치가 missing inputs인지 결정해야 함

3 Text Infilling

  • 가장 성능 좋았던 방식
  • 포아송 분포에서 span length를 추출한 길이만큼의 text span을 sampling하고 각 token을 [MASK] token으로 대체
    • span: 긴 문장의 토큰들
    • 길이가 0인 경우, [MASK] token을 삽입하는 것과 동일
  • span으로부터 몇 개의 token이 사라졌는지를 예측하는 방법을 모델에게 가르치는 방식

4 Sentence Permutation
document는 full stops에 기반을 두어 문장 단위로 나누어지고, 해당 문장들을 랜덤 순서로 섞음

5 Document Rotation

  • 임의의 token을 뽑고, document를 해당 token으로 시작하도록 rotate
  • 해당 task는 모델에게 문장의 시작이 어디인지를 학습하도록 도움

3 Fine-tuning BART

BART에 의해서 만들어지는 representations는 downstream applications에서 여러 방법들로 쓰일 수 있다

BART를 다음 4가지 task에 어떻게 사용할 것인지를 다루는 부분

3.1 Sequence Classification Tasks

  • 사용 방법
    • 같은 input을 encoder, decoder에 넣어주고,
    • decoder token의 final hidden state는 새로운 multi-class linear classifier에 fed
  • 해당 방법은 BERT에서의 CLS token과 관련이 있다

    BERT에서의 CLS token은 BERT가 task를 수행하기 위해서 추가한 특별한 토큰

    • 하지만, BART의 경우 end에 additional token을 추가해서 -> decoder에서 해당 token을 위한 representation은 complete한 input으로부터 decoder states에 attend할 수 있다(Fig 3a)

3.2 Token Classification Tasks

  • 예) SQuAD를 위한 answer endpoint classification
    • complete document를 encoder와 decoder에 feed하고, decoder의 top hidden state를 각 단어를 위한 representation으로서 사용
    • representation은 token을 classify하기 위해서 사용

3.3 Sequence Generation Tasks

  • BART의 decoder가 autoregressive해서
    • sequence generation tasks(abstraction question answering, summarization)를 위해서 fine-tuned될 수 있음
      • 두 task 모두에서 -> information이 input으로부터 복사되고, 관련있는 denoising pre-training objective로 조작된다
  • 여기서 encoder input이 input sequence이고 decoder가 output을 autoregressively하게 생성.

3.4 Machine Translation

수행 목적과 과거 연구들의 한계

  • 목적: 영어로 번역하는 decoder의 성능을 높이기 위해서 다음 과정을 수행함
  • 이전 연구는 모델이 pre-trained encoders를 통합하는 과정을 통해서 향상할 수 있음을 보였다
    • 하지만 이를 통해서 얻는 이점은 decoder에서 한계가 많음.

논문의 해결책

  • BART model 전체(encoder와 decoder)를 single pretrained decoder로서 사용하기!
    • bitext에서 학습된 새로운 encoder parameters 집합을 더해서(Figure 3b)
  • 더 구체적으로,
    • encoder에서 embedding layer를 new randomly initialized encoder로 교체
    • 모델은 end-to-end로 훈련되는데, 새로운 encoder를 훈련할 때
      • foreign words를 input으로 map하고자 함
    • 새 encoder는 original BART model로부터 seperate vocabulary로 쓰일 수
    • 효용성을 검증하고자 freeze 과정을 통해 모델 결과를 확인

4 Comparing Pre-training Objectives

BART는 이전 연구보다 pre-training 동안에 더 넓은 범위의 noising 방법들을 사용했다
그래서 해당 부분에서는 그 방법들을 토대로, base size models(6 인코더, 6 디코더, hidden size: 768)를 사용해서 여러 옵션들과 비교하고 evaluate했다

4.1 Comparison Objectives

  • 많은 pre-training objectives가 제안되는 동안, 해당 모델들을 비교하는 것이 어려웠음
    • 왜냐하면, training data, training resources, architectural differences between models, fine-tuning procedures에서의 차이 때문에
  • 본 논문에서는 pre-training objective와 연관이 없는 것들의 차이를 최대한 통제하며 비교함

5가지의 모델에 대해서 차이를 통제하기 위해서 어떻게 설정하였는지에 대해서 상세하게 설명하고 있다

    1. Language Model
    • GPT와 유사하게 left-to-right Transformer language model을 훈련함
    • 해당 모델은 BART decoder와 동일함(cross-attention 빼고)
    1. Permuted Language Model
    • XLNet에 초점을 두어서 token의 1/6dmf sample하고 그들을 random order autoregressively로 generate함
    • 다른 모델과의 일관성을 위해, relative positional embeddings나 attention across segments from XLNet는 시행X
    1. Masked Language Model
    • BERT와 동일하게 15%의 token을 [MASK]로 바꾸고, original token을 독립적으로 예측하기 위해서 model을 훈련함
    1. Multitask Masked Language Model
    • UniLM에서 Masked Language model을 additional self-attnetion masks로 train함
    • self-attnetion mask는 랜덤으로 뽑힘
    1. Masked Seq-to-Seq
    • MASS에서 영감을 받아 50% token을 포함하는 span을 mask함
    • masked tokens를 예측하고자 sequence to sequence model을 훈련함

4.2 Tasks

모델에 대한 소개를 끝냈으니 수행한 여러 Task에 대해서도 자세히 소개한다

  • SQuAD
    • Wikipedia 문단에 대한 extractive question answering 태스크
    • 정답은 document context가 주어졌을 때 거기서 추출된 text spans
    • BERT와 비슷하게, BART에서는 question과 context를 concat한 것을 인코더의 입력으로 하고, 디코더를 통해 예측하도록 함.
    • 모델은 각 토큰의 시작과 끝 인덱스을 예측하는 분류기가 포함
  • MNLI
    • 하나의 문장이 다른 문장을 entail하는지 예측하는 bitext classification task
    • BART 모델은 EOS token이 추가된 두 개의 문장을 합치고, 이를 인코더와 디코더에 넣음
    • BERT와 달리, EOS 토큰의 표현이 문장 관계를 분류하는데 쓰임
  • ELI5
    • 긴 형식의 abstractive question answering task
    • 모델은 문제와 추가적인 문서를 concat한 것으로 조건으로 주어 답을 생성.
  • XSum
    • 함축된 요약을 생성하는 뉴스 요약 태스크
  • ConvAI2
    • 대화의 답변에 대한 generation 태스크
    • context와 persona(화자)를 조건으로 줌.
  • CNN/DM
    • 뉴스 요약 데이터셋
    • 요약본은 입력 문서와 밀접하게 연관있음

4.3 Results

모델들과 task들을 토대로 한 결과를 보고한다

1) Pre-training의 성능은 태스크별로 차이가 있음

  • Pre-training의 효율성은 태스크에 크게 의존
  • 예) Language Model의 경우 ELI5에 최고의 성능, SQuAD에서는 최악의 성능.

2) 토큰 마스킹은 중요함

  • Rotating Document나 Permuting Sentences 기반 Pre-training은 해당 목적 함수로만 훈련시켰을 때 성능이 좋지 않음.
  • 성공적인 방법은 Token Deletion이나 Token Masking, Self-Attention Mask를 사용하는 방법.
    • Token Deletion는 생성 태스크에서 Token Masking보다 더 좋은 성능을 보임.

3) Left-to-right 기반 언어 모델은 generation에 효과적임

  • Masked Language Model과 Permuted Language Model은 generation task(ConvAI2)에서 다른 것들보다 성능이 낮음
    • 이 두 모델은 사전학습 단계에서 left-to-right auto-regressive 언어 모델링을 적용하지 않은 모델들임.

4) SQuAD에서 양방향 인코더는 중요함.

  • 이전 연구에서 left-to-right 디코더가 SQuAD 태스크에서 성능이 안좋았음.
    • classification decisions에서 future context가 중요하기 때문
    • 하지만 BART는 bidirectional layers 수가 거의 절반인데도 유사한 성능을 거둠

5) 사전 학습 방법론 만이 중요한 요소는 아님

  • 우리의 Permuted Language Model은 기존 XLNet보다 성능이 낮았음.
    • 이 차이는 다른 architectural improvements를 포함 안해서
      • relative-position embeddigs나 segment-level recurrence 같은

6) Pure language models이 ELI5에서 최고의 성능

  • ELI5 dataset
    • BART를 썼을 때 다른 task보다 PPL이 높았다 + 더 나아가, BART가 아닌 다른 모델이 성능이 더 좋았던 generation task임
    • Pure language models이 ELI5에서 최고의 성능 => input에 의해서 덜 제약을 받았을 때 BART가 덜 효율적임을 입증

7) BART가 가장 일관성 있게 강력한 성능 달성

  • ELI5 제외, BART를 Text Infilling으로 학습한 모델이 모든 태스크에서 좋은 성능을 보임

5 Large-Scale Pre-training Experiments

최근 연구들은 downstream performance가 pre-training이 large batch size로 scale되었을 때 성능향상이 크다는 것을 입증했다

BART가 이 영역에서 얼마나 잘 수행하는지 보이고, downstream task를 위한 유용한 모델을 만들고자 RoBERTa 모델과 같은 scale에서 BART를 훈련했다

5.1 Experimental Setup

어떠한 환경에서 BART를 훈련했는지에 대한 설명이 기술되어있다

  • large model: 12 layer(encoder + decoder), hiddensize 1024
  • RoBERTa와 동일하게, batchsize 8000 으로 500000 step 훈련
  • documents들은 GPT-2와 같은 byte-pair encoding으로 토큰화됨
  • #4에 있는 결과에 기반해서, text infilling과 sentence permutation의 결합을 사용
    • 각 문서 토큰의 30%를 mask했고 모든 문장을 permute함
    • 비록 sentence permutation은 CNN/DM summarization dataset에서만 좋은 성능 향상을 보여주었지만,
      • 본 논문에서는 larger pre-trained models로 해당 task가 학습되기에 더 나을 것임을 가정함
    • 모델이 data에 더 잘 fit하게 하고자, 우리는 training의 마지막 10%를 dropout으로 못쓰게 함
    • 우리의 pre-training data는 160 gb의 news, books, stories, web text로 구성

5.2 Discriminative Tasks

BART를 최근 approaches들과 함께 SQuAD, GLUE tasks에서 비교한 결과

  • BART 다음 가장 우세한 baseline은 RoBERTA
  • 대다수의 task에서 차이 작게, 유사한 성능을 거둠
    • generation task에서의 BART의 성능 향상이 classification performance에 의한 희생에 따른 것이 아님을 입증함(=둘다 잘했기 때문에 classification 상대적으로 못해서 generation을 잘하게 된 것은 아니다 !!)

5.3 Generation Tasks

  • BART는 이 때 input으로부터 standard sequence-to-sequence model에서 fine-tuned되어서 output text가 됨
  • fine-tuning 동안에 label smoothed cross entropy loss 사용
    • smoothing parameter set = 0.1
  • generation 동안
    • beam size = 5
    • remove duplicated trigrams in beam search
    • tuned the model with min-len, max-len, length penalty on the validation set

Summarization

  • summarization에서의 SOTA와의 비교를 위해 두 dataset에서의 결과를 보고함
  • 1) CNN/DailyMail
    • source sentence를 닮는 경향이 있음
    • extractive model이 성능이 좋다
    • BART가 다 앞섬
  • 2) XSum
    • abstractive해서 extractive model이 오히려 못한다
    • BART가 이전 최고의 성과를 앞섬(6.0 ROUGE나 높음)

요약 관련 질적인 연구들은 #6에 나온다!

Dialogue

  • CONVAI2에서 dialogue response generation을 평가
    • BART가 two automated metrics에서 모두 앞섬

Abstractive QA

  • ELI5 dataset을 사용해 long free form answers를 생성하는 모델의 능력을 평가함
  • BART가 이전 최고 성과를 1.2 ROUGE-L로 앞지름

5.4 Translation

  • WMT16 Romanian-English에서도 성능을 평가

6 Qualitative Analysis

BART가 어떻게 우수한 성과를 거두었는지를 이해하고자 generation을 qualitatively하게 분석함

Table7 해석

  • BART에서 생성된 예제들
  • 예시들은 WikiNews article에서 가져옴
  • 결과
    • 유창하지만, 인풋에서 복사된 phrases가 적어서 꽤 추상적임
    • 더 나아가 generally factually accurate하고
    • supporting evidence를 input document를 통해서 잘 통합함(background knowledge(예. correctly completing names, or inferring that PG&E operates in California)와 함께)
    • 첫번째 예시의 경우, 'fish are protecting reefs from global warming'은 텍스트에서의 non-trivial inference를 요한다, 하지만 'the work was published in Science'라는 주장은 source에 의해서 support 받지 않았다

해당 사례를 통해 BART pretraining이 NLU와 NLG의 결합에서 학습되었음을 입증함

8 Conclusion

  • BART: pre-training approach
    • corrupted documents를 original로 map하는 것을 학습
    • discriminative tasks에서 RoBERTa와 유사한 성능
    • 수많은 text generation task에서는 SOTA
  • 후속 연구는 pre-training을 위한 corrupting documents에 대한 새로운 방법이 되어야 한다
    • 아마 그들을 specific end task로 맞추는

BART 후기

좋은 점

  • related work가 후반부에 있는 이유는 모르겠지만, 여러 모델들을 잘 정리해주어서 좋다
    • 2019년도에 BART를 읽었더라면 소개해주는 모델이 더 생소했을 것 같아서 더 도움이 되었을 것 같다
  • related work의 후반부에서 BART가 왜 decoder에 초점을 두었는지 이유들이 명쾌하게 설명된 것 같아서 좋았다
    • 이전 연구들이 encoder를 발전시켰는데 decoder를 더 발전시킬 수 있을 것 같으니까 ! 해보겠다 ! 이런 느낌
  • objective와 task를 여러 개 설정해서 BART의 우수성을 입증하고 7가지의 발견점을 찾아낸 부분이 좋았다
    • 논문을 위해 정말 꼼꼼하게 분석한 것 같은 !!

아쉬운 점

  • Tab3,4,5는 본문에서 여기 테이블 봐!! 에 대한 언급이 없었던 점

코드

modeling_bart.py

  • BARTEncoder
class BartEncoder(BartPretrainedModel):
    def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None):
        super().__init__(config)

        self.dropout = config.dropout
        self.layerdrop = config.encoder_layerdrop

        embed_dim = config.d_model
        self.padding_idx = config.pad_token_id
        self.max_source_positions = config.max_position_embeddings
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

        self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)

        if embed_tokens is not None:
            self.embed_tokens.weight = embed_tokens.weight

        self.embed_positions = BartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            embed_dim,
        )
        self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)])
        self.layernorm_embedding = nn.LayerNorm(embed_dim)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input = input_ids
            input_ids = input_ids.view(-1, input_ids.shape[-1])
        elif inputs_embeds is not None:
            input = inputs_embeds[:, :, -1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

        embed_pos = self.embed_positions(input)
        embed_pos = embed_pos.to(inputs_embeds.device)

        hidden_states = inputs_embeds + embed_pos
        hidden_states = self.layernorm_embedding(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            if head_mask.size()[0] != (len(self.layers)):
                raise ValueError(
                    f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
                    f" {head_mask.size()[0]}."
                )

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = random.uniform(0, 1)
            if self.training and (dropout_probability < self.layerdrop):  # skip the layer
                layer_outputs = (None, None)
            else:
                if self.gradient_checkpointing and self.training:

                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs, output_attentions)

                        return custom_forward

                    layer_outputs = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(encoder_layer),
                        hidden_states,
                        attention_mask,
                        (head_mask[idx] if head_mask is not None else None),
                    )
                else:
                    layer_outputs = encoder_layer(
                        hidden_states,
                        attention_mask,
                        layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                        output_attentions=output_attentions,
                    )

                hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )
  • BartDecoder
class BartDecoder(BartPretrainedModel):
    def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None):
        super().__init__(config)
        self.dropout = config.dropout
        self.layerdrop = config.decoder_layerdrop
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)

        if embed_tokens is not None:
            self.embed_tokens.weight = embed_tokens.weight

        self.embed_positions = BartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
        )
        self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)])
        self.layernorm_embedding = nn.LayerNorm(config.d_model)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
            ).to(inputs_embeds.device)

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
                inputs_embeds.device
            )
            combined_attention_mask = (
                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            input = input_ids
            input_shape = input.shape
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            input = inputs_embeds[:, :, -1]
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        # past_key_values_length
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input) * self.embed_scale

        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, input_shape, inputs_embeds, past_key_values_length
        )

        # expand encoder attention mask
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])

        # embed positions
        positions = self.embed_positions(input, past_key_values_length)
        positions = positions.to(inputs_embeds.device)

        hidden_states = inputs_embeds + positions
        hidden_states = self.layernorm_embedding(hidden_states)

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
        next_decoder_cache = () if use_cache else None

        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
            if attn_mask is not None:
                if attn_mask.size()[0] != (len(self.layers)):
                    raise ValueError(
                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                        f" {head_mask.size()[0]}."
                    )

        for idx, decoder_layer in enumerate(self.layers):
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            dropout_probability = random.uniform(0, 1)
            if self.training and (dropout_probability < self.layerdrop):
                continue

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, use_cache)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    head_mask[idx] if head_mask is not None else None,
                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    cross_attn_layer_head_mask=(
                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
                    ),
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )
            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )
  • BartDecoderLayer
class BartDecoderLayer(nn.Module):
    def __init__(self, config: BartConfig):
        super().__init__()
        self.embed_dim = config.d_model

        self.self_attn = BartAttention(
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
        )
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout

        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.encoder_attn = BartAttention(
            self.embed_dim,
            config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = True,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states

        # Self Attention
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        # add present self-attn cache to positions 1,2 of present_key_value tuple
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=self_attn_past_key_value,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Cross-Attention Block
        cross_attn_present_key_value = None
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            residual = hidden_states

            # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=cross_attn_past_key_value,
                output_attentions=output_attentions,
            )
            hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
            hidden_states = residual + hidden_states
            hidden_states = self.encoder_attn_layer_norm(hidden_states)

            # add cross-attn to positions 3,4 of present_key_value tuple
            present_key_value = present_key_value + cross_attn_present_key_value

        # Fully Connected
        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        if use_cache:
            outputs += (present_key_value,)

        return outputs

4개의 댓글

comment-user-thumbnail
2023년 3월 22일

기존 모델들의 decoder를 사용하는데 단점이 있다는 걸 새롭게 지적해준 점에서 통찰력있는 논문 같습니다!!! 다양한 denosing 방법도 알게되어서 유익하였습니다! 감사합니다~

답글 달기
comment-user-thumbnail
2023년 3월 22일

꼭 읽어봐야지 했던 논문인데, 넘나 잘 이해되게 설명해주셔서 감사합니다 ㅎㅎ BART가 이해와 생성 모두 잘해야 하는 분야인 Summarization 분야의 SOTA인걸로 알고 있는데, Encoder의 Comprehension 능력과 Decoder의 Generation 능력이 잘 버무려져서 나온 결과라 생각되네요 ! 다양한 document corruption 방식 모두 실험한 것 또한 인상적이었습니다 !

답글 달기
comment-user-thumbnail
2023년 3월 29일

계속 모듈을 갖다 쓰기만 했는데, 이렇게 논문을 잘 정리해주셔서 감사합니다! 상세한 정리 좋아요!

답글 달기
comment-user-thumbnail
2023년 3월 31일

BART에 대해서 자세히 알지 못했었는데, 자세히 꼼꼼하게 설명해주셔서 이해가 잘 되었습니다 !! 감사합니다 ~_~

답글 달기