Longformer

장준원·2021년 11월 5일

1. Introduction

3. Longformer

기존 Transformer들의 Self-Attention 연산은 Input Sequence n에 대해 O(n2)O(n^2)의 시간 및 공간 복잡도를 가지고 있습니다. 논문의 연구진들은 서로 관련이 있는 Token끼리만 Attention할 수 있도록 'Attention Pattern'을 지정해줌으로써 Full-Attention의 연산량을 희소화하려고 하였습니다. 본인들의 방법론으로 O(n)O(n)의 시간 및 공간 복잡도로 낮출 수 있다고 합니다.

3-1. Attention Pattern

  • Sliding Window

    이미지 출처: https://blog.agolo.com/understanding-longformers-sliding-window-attention-mechanism-f5d61048a907

    이미지 출처: https://blog.agolo.com/understanding-longformers-sliding-window-attention-mechanism-f5d61048a907

하나의 Token이 양 옆으로 1/2w만큼만 Attention을 할 수 있도록 제한을 두는 방법입니다. 이렇게 되면 가장 아래에 있는 layer에 대해서 하나의 Token이 같는 receptive field는 w이지만, layer을 쌓을수록 recepticve field가 커져서 가장 위에 있는 layer는 모든 다른 token들과의 관계를 학습할 수 있습니다. Sliding window 기법을 쓰면 O(n×w),where w is window sizeO(n \times w), where \ w \ is \ window \ size의 시간 및 공간 복잡도로 낮출 수 있다고 합니다.

  • Dilated Sliding Window

    이미지 출처: https://blog.agolo.com/understanding-longformers-sliding-window-attention-mechanism-f5d61048a907

w의 사이즈가 작게 되면, full-sequence를 보기 위한 receptive field를 만드려면 여러 layer를 쌓아야 합니다. 마찬가지로 연산량이 증가하는 구조입니다. 따라서 저자들은 window size는 유지하되, window의 각 칸마다 masking을 두어서 attention에 활용되는 연산량은 유지하면서 각 token의 receptive field는 넓게 가져가는 Dilated Sliding Widow 기법을 제시합니다. 이러면 더 적은 Layer의 수로 모든 다른 token들을 커버하는 receptive field를 만들 수 있죠. 하지만, 이러면 당연하게도 가장 낮은 layer에서는 주변의 token을 정확히 다 보지 못해 정보가 attention이 유의미하게 local context를 파악하지 못하죠. 따라서 저자들은 가장 낮은 layer 몇개에 대해서는 Sliding Window를 적용하고, 이후에는 Dilated Sliding Window를 적용한다거나 각 Layer의 특정 head는 Sliding Window를 나머지 head에 대해서는 Dilated Sliding Window 적용해 적은 연산량을 유지하면서 전반적인 성능 향상을 이룰 수 있다고 합니다. Dilated Sliding window 기법을 쓰면 O(n×w×d),where d is dilation sizeO(n \times w\times d), where \ d \ is \ dilation \ size 의 시간 및 공간 복잡도로 낮출 수 있다고 합니다.

  • Global Attention

    이미지 출처: https://arxiv.org/abs/2004.05150

하지만 모델이 해결하려는 Task마다 모든 Token에 global한 attention이 걸려야 하는 경우도 있습니다. Classification 같은 경우에는 전체 문장의 정보를 [CLS] Token에 담고, QA같은 경우에는 Question의 token들이 모든 document와 attention을 하면서 정답 span을 찾아냅니다. 따라서 저자들은 해결하는 문제마다 전체 Token과의 관련성을 파악해야 하는 Classification의 [CLS]나 QA의 Question의 token들은 기존과 동일하게 모든 Token들과의 Attention을 주는 Global Attention을 도입합니다. 해당 Token들의 수는 전체 길에 nn에 비해서 상대적으로 작으므로 O(n)O(n)의 시간 및 공간 복잡도를 유지할 수 있다고 합니다.

  • Linear Projections for Global Attention

저자들인 Sliding Window Attention과 Global Attention을 서로 다른 Wk,Wq,WvW_{k}, W_{q}, W_{v}로 projection해서 연산을 수행했다고 합니다.

3-2. Implementation

저자들은 Pytorch/Tensorflow에 Sliding Window Attention이나 Dilated Sliding Window Attention을 수행하기 위한 banded matrix multiplication이 구현되어 있지 않아 직접 구현했다고 합니다. 각 window를 loop으로 돌면서 처리하는 loop 방법 (non-vectorized 방법이므로 시간이 오래 걸림), 직접 각 window별로 chunk를 나누어서 attention을 하는 chunk, TVM을 활용해 직접 cuda kernel을 최적화한 cuda 방법이 있습니다. 제안되는 실행 방법론 모두 seqeunce length가 증가함에 따라 공간복잡도는 Full-Attention에 비해서 떨어지지만, loop 방법의 경우 시간 복잡도는 증가하는 확인할 수 있습니다.

이미지 출처:https://arxiv.org/abs/2004.05150

4. Autoregressive Language Model

저자들은 본인들이 제시한 Longformer Attention 기법을 Autoregressive Language Model에도 활용했습니다. Appendix B를 보면 Transforemer-XL을 기반으로 아래의 실험을 진행했다고 합니다.

4.1 Attention Pattern

AR 모델의 경우 낮은 layer에는 작은 window size를 적용하고, layer가 올라갈수록 window size를 키웠다고 합니다. 이를 통해 낮은 layer에서는 local context 정보를, 높은 layer에서는 entire context 정보를 보도록 유도했다고 합니다. 낮은 layer의 경우에는 dilated window를 적용하지 않아서 최대한 local context 정보를 활용ㅎ하도록 하였고, 높은 layer에서는 2개의 head에만 dilation 크기를 증가했다고 합니다.

4.2 Experiment Setup

본 연구에서는 character-level LM을 기반으로 실험을 진행했습니다. 저자들은 모델이 longer context를 보기 전에 local context에 대해서 많은 gradient update가 필요하다는 것을 발견하고, 훈련 단계를 5단계로 나누고 sequence length를 2048에서 23,040으로, windowe size는 2배씩 증가하면서, learning rate는 1/2씩 감소하면서 훈련을 진행했다고 합니다.

4.2.1 Test Result

저자들이 제시한 Lonformer-base model은 text8enwik8 dataset에 대해서 비슷한 파라미터를 가진 모델들 중 가장 좋은 성능을 보였다고 합니다. 해당 실험들이 Large Model로 진행하기에는 비용이 커서 enwik8 dataset에 대해서만 Large Model을 평가했고, SOTA 모델들에 비해서 절반 이하의 파라미터로 필적할만한 성과를 거두었다고 합니다.

이미지 출처: https://arxiv.org/abs/2004.05150

4.2.2 Ablation Study

저자들은 text8과 dataset에 대해서 window size 변화에 따른 성능 변화를 검증해보았습니다. 실험 결과, layer 수가 증가할수록 window size를키우는게 가장 좋은 성과를 거두었다고 합니다. 또한 Dilation을 안주는 것보다 2개 head에 대해서 주었을 때 성능이 좋았다고 합니다.

이미지 출처: https://arxiv.org/abs/2004.05150

5. Pretraining and Finetuning

6. Tasks

7. Longformer-Encoder-Decoder

Bert와 같은 Encoder-only 모델이 classification과 관련된 NLP문제를 푸는데 특화되어있지만, summarization과 같은 문제를 풀 때는 BART와 T5같은 encoder-decoder model이 더 효율적입니다. 저자들은 기존 1024개의 Token만을 Encoder를 받을 수 있는 BART 모델의 encoder에 Longformer 구조를 적용하고자 하였습니다. 저자들은 pre-trained BART의 파라미터를 가져와서 Encoder에 Longformer의 sliding window + global attention을 적용하였습니다. window size는 1024로 두었고, 첫 <\s> token에만 global attention을 주었다고 합니다. Encoder의 Input을 16K로 맞춰주기 위해 BART의 positional embedding을 16개로 복사해서 positional embedding으로 활용했다고 합니다. Decoder의 경우에는 이전과 그대로 encoder와의 cross-attention, decoder 자체의 left-only attention을 활용했다고 합니다.

이미지 출처: https://arxiv.org/abs/2004.05150

arXiv summarization에 대해서 실험한 결과, 저자들은 본인들이 제시한 Longformer가 비슷한 시기에 제시되었던 Bigbird에 비해서 더 긴 길이의 token을 처리할 수 있고 (Bigbird는 4K) 추가적인 pre-training 없이 더 좋은 성능을 보였다고 주장합니다. 또한 Figure3에서 보시는 것처럼 처리하는 encoder의 token 수가 많아질수록 ROUGE Score가 향상되는 것을 통해 긴 input 처리의 효용성을 증명합니다.

profile
공부하는 학부생

0개의 댓글