[논문 리뷰]Transformer - XL : Attentive Language Models Beyond a Fixed - Length Context

찬호·2024년 1월 3일
0

preface

1. Introduction

언어 모델의 특징은 문장을 다루기 때문에, 문장의 길이가 모델 성능을 좌지우지 하는 경우가 매우 많습니다. 그렇기에 긴 문장을 다룰줄 아는 것은 굉장히 중요합니다. 대표적인 언어 모델로는 RNN, LSTM 등이 사용되었는데, 기울기 소실 문제 등으로 긴 문장들을 다루기 힘들어서 recurrent를 기반으로 하는 모델은 현재 사용이 거의 되지 않습니다. 이 당시에는 Transformer 의 파격적인 등장으로 Transformer로 모든 구조가 변환되었습니다.

하지만 Transformer도 문제점이 있었는데, 임베딩을 하고, 문장을 학습하는 과정에서 모델이 받아드릴 수 있는 문자의 개수가 제한적이라는 점입니다. 기존의 트랜스포머 논문과 다르게, Al-Rfou et al. (2018) 에서는 fix_length_vector 때문에 문장들을 segment 단위로 나눠서 학습에 사용을 하였습니다.이는 최대한 정보 유지를 위함입니다. (기존의 transformer 구조는 그냥 max_length 기준으로 자름.)

예를 들어, 한 언어모델이 받아드릴 수 있는 문자의 수가 10개라면, ‘The quick brown fox jumps over the lazy dog’ 와 같은 문장은 ["The quick “, "brown fox "], [ "jumps over", " the lazy ", "dog"] 로 나눌 수 있습니다. 나누어진 토큰들은 10개이하의 문자들로 구성된 것입니다. 여기서 문제는 10개이하로 나누었을 때, fox와 jumps가 관련이 있을 수도 있는데, 세그먼트 단위로 다시 문장을 구성하니 여기서 무시를 해버리게 됩니다. 저자는 이러한 문제점을 context fragmentation으로 지칭합니다.

저자는 이러한 문제를 해결하기 위해 Transformer-XL 구조를 제안합니다. 이 구조는 Self-attention 구조에 recurrent를 대입한 것입니다. 기존의 정보들을 다 유지하면서, 학습을 시킬 수 있습니다. 그리고 relative positional encoding을 사용하여 temporal confusion을 최소화 시킬 수 있습니다.

저자는 Language model에 대한 발전을 언급하고 있으며, 특히 긴 문장을 다루는 논문들을 소개했습니다. 그 중 처음으로 Transformer 구조를 사용한 work 이며, long - term dependency 해결 해준 language model이라고 소개하고 있습니다.

3. Model

token x=(x1,,xT)\mathbf{x} = (x_{1}, \dots, x_{T}) 가 주어졌을 때, 언어 모델은 T+1T+1 일때의 가장 높은 확률값을 구하는 것입니다. 위 논문의 저자도 위와 같은 방식을 참고하였고, context x<t\mathbf{x}_{<t}가 fixed_size hidden state로 바뀌며, 이 값들을 기준으로 다음 값들을 예측한다고 말하였습니다.

3.1 Vanila Transformer Langauge Models

트랜스포머 구조를 잘 학습하기 위해서는 긴 문장을 어떻게 encode 할지를 정하는 것입니다. 제한된 환경에서 트랜스포머에 전체 문장을 넣어서 학습을 하면 되겠지만, 학습 환경과 제한된 resource로는 불가능합니다. 그래서 전체 corpus를 작은 segment 단위로 나눠서 학습을 진행하는 것입니다. 저자는 이를 vanilla model 이라고 부릅니다.

vanilla model는 문장 하나를 segment 2개로 나눠서 학습을 하고 있습니다. 여기서 2가지의 문제점을 제기할 수 있는데, 첫 번째로 가장 의존도가 높은 길이의 문장 길이는 segment에 의해서 제한이 되기 때문에 모든 정보를 다 가져오지 못한다는 점, 두 번째로 padding을 통해서 문장이나 단어를 보충할 수는 있지만, 이는 context fragmentation 문제를 제기한다는 점 입니다.

모델 학습 과정에서 vanilla model은 모든 segment로 학습을 하지만, prediction 과정에서는 마지막 layer 만을 사용해서 inference를 하게 됩니다. segment 단위로 결과를 도출하며, context fragmentation issue를 해결해줍니다. 하지만 이 과정은 많은 계산량을 요구합니다.ㅇ

3.2 Segment-Level Recurrence with State Reuse

저자는 위와 같은 문제를 해결하기 위해서 트랜스포머 구조에 recurrence 메커니즘을 도입하였습니다. hidden state sequence들은 fixed 되고, cached 되어서 다음 segment 학습에 도움을 줍니다. 이는 과거의 모든 정보들을 사용하게 해주고, 특히 longer term dependency를 유지시키고, context fragmentation을 해결해줍니다.

길이가 L인 연속적 segment 두 개를 sτ=[xτ,1,,xτ,L]s_{\tau} = [x_{\tau, 1}, \dots, x_{\tau,L}] and sτ+1=[xτ+1,1,,xτ+1,L]s_{\tau+1} = [x_{\tau+1, 1}, \dots, x_{\tau+1,L}] 로구 성하였습니다. sτs_{\tau}의 hidden state는 hτnh_{\tau}^n으로 쓰이며, 식은 다음과 같습니다.

hτ+1n1~=[SG(hτn1)hτ+1n1]qτ+1n,kτ+1n,vτ+1n=hτ+1n1WqT,hτ+1n1~WkT,hτ+1~WvThτ+1n=TransformerLayer(qτ+1n,kτ+1n,vτ+1n)\widetilde{h_{\tau+1}^{n-1}} = [\mathbf{SG(h_{\tau}^{n-1}})\odot \mathbf{h_{\tau+1}}^{n-1}]\\ \mathbf{q_{\tau+1}}^n,\mathbf{k_{\tau+1}}^n, \mathbf{v_{\tau+1}}^n = \mathbf{h_{\tau+1}^{n-1}}\mathbf{W_{q}^T}, \mathbf{\tilde{h_{\tau+1}^{n-1}}}\mathbf{W_{k}^{T}}, \mathbf{\tilde{h_{\tau+1}}}\mathbf{W_{v}^T}\\ \mathbf{h_{\tau+1}^n} = \mathbf{Transformer-Layer(\mathbf{q_{\tau+1}}^n,\mathbf{k_{\tau+1}}^n, \mathbf{v_{\tau+1}}^n)}

여기서 SG 함수는 stop gradient를 의미하며, [huhv][h_{u}\odot h_{v}]의 notation을 사용하여 두 개의 hidden sequence를 합쳐서 구성한 것이다.

기존의 트랜스포머와의 확연한 차이는 self attention의 구성인 key값인 kτ+1n\mathbf{k_{\tau+1}}^n와 value 값인 vτ+1n\mathbf{v_{\tau+1}}^n

에 차이가 있습니다. 이 두 값은 hτ+1n1~\widetilde{h_{\tau+1}^{n-1}}으로 계산이 되며, hτn1h_{\tau}^{n-1} 은 이전의 segment 값들을 저장합니다. 그리하여 context 값은 2개의 segment 로 계산이 됩니다.

하지만 hτn1h_{\tau}^{n-1}hτ+1nh_{\tau+1}^{n}사이의 recurrent한 dependency는 segment 별로 한 layer씩 옮겨가며, 결과적으로 layer와 segment의 길이로 인해서 large possible dependency length가 linear 하게 증가하게 되는 것이다. 위 방법은 BPTT와 굉장히 유사하지만 마지막 layer 만을 계산하는데 쓰는 것이 아니라 모든 layer를 사용한다는 점에서 차이가 있다.

저자가 말하는 다른 장점으로는 계산 속도가 굉장히 빠르다는 점입니다. 특히 이전의 segment 들을 representation 한 것들은 모든 값들을 계산하는 것이 아니라 바로 이전의 값만을 계산하는데 사용하기 때문에 굉장히 빠릅니다.

3.3 Relatvie Positional Encodings

hidden state를 사용하는 과정에서 positional information 이전 과정에는 적용할 수 없다고 합니다. 이를 만약에 기본 트랜스포머 구조에 적용을 하게 되면, hτ+1=f(htau,Esτ+1+U1:L)hτ=f(hτ1,Esτ+U1:L)\mathbf{h_{\tau+1}} = f(\mathbf{h_{tau}, E_{s_{\tau+1}}+U_{1:L}})\\ \mathbf{h}_{\tau} = f(\mathbf{h}_{\tau-1}, \mathbf{E_{s}}_{\tau} + \mathbf{U}_{1:L})로 볼 수 있는데, 여기서 Esτ\mathbf{E_{s}}_{\tau}Esτ+1\mathbf{E_{s}}_{\tau+1}의 구조가 같은 포지션 정보인 U1:L\mathbf{U}_{1:L}에 영향을 받게 됩니다. 즉, 모델은 xτ,jx_{\tau,j}xτ+1,jx_{\tau+1,j} 사이를 구분하지 못하게 되며, 포지션을 읽지 못하게 됩니다.

위와 같은 문제를 해결하기 위해서 hidden state의 상황에서 the relative positional information 를 encode 하는 것이다. 전통적인 Transformer 모델에서는 각 단어 또는 토큰의 절대적인 위치 정보를 embedding vector에 포함시켜 처리한다. 하지만 Transformer-XL에서는 상대적인 위치 정보를 이용하는데, 각 쿼리 벡터 qτ,iq_{\tau, i}가 키 벡터 kτik_{\tau≤i}에 만날 때, 키 벡터들의 절대 위치를 알 필요 없이, 각 키 벡터와 쿼리 벡터 사이의 상대적 거리, 즉 iji−j만 알면 되는 것이다.

구체적으로 말해보자면, 상대 거리 정보를 각 레이어의 attention 점수에 동적으로 주입함으로써, 쿼리 벡터는 xτ,jx_{\tau, j}xτ+1,jx_{\tau+1, j} 둘 사이의 거리에 기반하여 쉽게 구별할 수 있습니다. 이는 state reuse 메커니즘을 가능하게 하면서도 절대 위치 정보를 상대 거리로부터 재귀적으로 복구할 수 있으므로, 시간적 정보를 잃지 않게 됩니다.

필자는 위를 쓰면서도 이 말이 무슨 말인지 모르다가 예시를 보고 이해를 했기에 예시를 들어봅니다. 예를 들어, 문장: "The cat sat on the mat." 이 문장을 Transformer 모델이 처리한다고 가정해보겠습니다. 예를 들어, "cat"이라는 단어가 쿼리 벡터 qcatq_{cat}로, "the", "sat", "on", "the", "mat"이 각각 키 벡터 kthek_{the}, ksatk_{sat}, konk_{on}, kthek_{the}, kmatk_{mat}로 변환됩니다.

절대 위치 인코딩 (전통적인 Transformer 방식)

  • 전통적인 Transformer에서는 각 단어가 문장에서 몇 번째 위치에 있는지 (예: "cat"은 2번째, "sat"은 3번째)를 절대 위치로 인코딩하여 이 정보를 사용합니다.
  • 이와 같은 방식은 각 단어의 정확한 위치를 모델이 알 수 있도록 해줍니다.

상대 위치 인코딩 (Transformer-XL 방식)

  • 상대 위치 인코딩에서는 "cat"이 "the", "sat", "on" 등과 얼마나 멀리 떨어져 있는지를 상대적으로 인코딩합니다. 예를 들어, "cat"에서 "the"까지의 거리는 -1 (하나 이전), "sat"까지의 거리는 -2 등입니다.
  • Transformer-XL에서는 이러한 상대적 거리를 인코딩하는 행렬 RR을 사용하여, "cat"과 다른 단어들 사이의 거리를 모델이 이해할 수 있게 합니다.
  • 위와 같은 경우에는, "cat"은 "the"와 1만큼 떨어져 있고, "sat"과는 2만큼 떨어져 있음을 인식합니다. 따라서 "cat"이 "the" 다음에 오는지 "sat" 다음에 오는지를 상대적인 위치로 파악할 수 있습니다.

기존의 트랜스포머 구조에서는 같은 segment에 쿼리와 키 사이의 어텐션 스코어는 아래와 같이 분해할 수 있습니다.

Ai,jabs=ExiTWqTWkExj+ExjTWqTWkUj+UiTWqTWkExj+UiTWqTWkUj\mathbf{A}_{i,j}^{abs} = \mathbf{E}_{x_{i}}^T\mathbf{W}_{q}^T\mathbf{W}_{k}\mathbf{E}_{x_{j}} + \mathbf{E}_{x_{j}}^T\mathbf{W}_{q}^T\mathbf{W}_{k}\mathbf{U}_{j} + \mathbf{U}_{i}^T\mathbf{W}_{q}^T\mathbf{W}_{k}\mathbf{E}_{x_{j}} + \mathbf{U}_{i}^T\mathbf{W}_{q}^T\mathbf{W}_{k}\mathbf{U}_{j}

앞에 첫번째 식은 쿼리와 키 사이의 내적을 어텐션에, 두번째는 쿼리와 키 사이의 상호작용 효과를 어텐션에, 세번째는 쿼리의 상대적인 위치정보를 어텐션에 포함, 네번째는 쿼리와 키의 상대적 위치 정보만을 이용하여 어텐션에 넣는 것으로 해석할 수 있습니다.

저자는 이와 같은 구조에서 다음과 바꿨습니다.

Ai,jabs=ExiTWqTWkExj+ExjTWqTWkRij+uTWqTWk,EExj+vTWqTWk,RRij\mathbf{A}_{i,j}^{abs} = \mathbf{E}_{x_{i}}^T\mathbf{W}_{q}^T\mathbf{W}_{k}\mathbf{E}_{x_{j}} + \mathbf{E}_{x_{j}}^T\mathbf{W}_{q}^T\mathbf{W}_{k}\mathbf{R}_{i-j} + u^T\mathbf{W}_{q}^T\mathbf{W}_{k,E}\mathbf{E}_{x_{j}} + v^T\mathbf{W}_{q}^T\mathbf{W}_{k,R}\mathbf{R}_{i-j}

두번째와 네번째의 절대적 위치 정보인 Uj\mathbf{U}_{j}를 상대적인 정보 구조로 바꾸고 학습 가능한 parameter u로 바꾸었는데, 이는 attentive bias가 다른 단어들에게 같게 유지됨으로써, 어느 쿼리 포지션에 없이 유지하고자 하였고, 마찬가지로 v도 똑같은 이유로 추가를 하였습니다. 마지막으로 weight matrix를 2개로 나누어 content에 집중된 key 값과 위치 정보에 집중된 key 값으로 나누고자 하였습니다.

그리하여 직관적으로, 첫번째 항은 컨텐츠 정보를, 두번째 항은 컨텐츠에 의존하는 포지션 bias를, 세번째는 전체적인 컨텐츠 bias를, 네번째는 글로벌 정보를 encode하는 방식으로 해석할 수 있습니다.

그리하여 위와 같은 구조를 조합해서 Transformer XL을 만들었습니다.

h^rn1=[SG(mrn1)hrn1]qrn,krn,vrn=hrn1Wqn,hrn1Wkn,hrn1WvnAr,i,jn=qr,inkr,jn+qr,inWk,RnRijAr,i,jn=+uTkr,jn+vTWk,RnRijarn=Masked-Softmax(Arn)vrnorn=LayerNorm(Linear(arn)+hrn1)hrn=Positionwise-Feed-Forward(orn)\hat{h}^{n-1}_r = \left[ SG(m^{n-1}_r) \circ h^{n-1}_r \right] \\q^n_r, k^n_r, v^n_r = h^{n-1}_r W^n_q, h^{n-1}_r W^n_k, h^{n-1}_r W^n_v \\A^n_{r,i,j} = q^n_{r,i} k^n_{r,j} + q^n_{r,i} W^n_{k,R} R_{i-j} \\\phantom{A^n_{r,i,j} =} + u^T k^n_{r,j} + v^T W^n_{k,R} R_{i-j} \\a^n_r = \text{Masked-Softmax}(A^n_r) v^n_r \\o^n_r = \text{LayerNorm}(\text{Linear}(a^n_r) + h^{n-1}_r) \\h^n_r = \text{Positionwise-Feed-Forward}(o^n_r)

4. Experiments

4.1 Main Results

여러 데이터셋과 모델을 비교해본 결과, Transformer XL이 가장 좋은 성능을 보여주었으며, 어느 모델에서든지 높은 성능을 보여주고 있다.

4.2 Ablation study

저자는 두 가지 실험을 추가로 하였는데, recurrence 메커니즘과 새로운 포지셔녈 인코딩에 관한 실험을 하였다.

먼저 WikiText-103 데이터를 사용하여 의존성을 판단하였다.

여기서 AI-Rfou와 Vaswani et al 구조는 absolute한 것으로, 사용이 되었는데, 절대적 위치 인코딩은 half loss에서만 반응이 잘 되었고, 이는 즉, 긴 문장에 가면 갈 수록 recurrence 메커니즘과 상대적 인코딩 구조가 좋은 퍼포먼스를 보여주었습니다.

두 번째 연구는 긴 텍스트를 분석할 때, context fragmentation의 문제를 어떻게 해결할 것인지에 대한 연구입니다. 저자는 long dependency가 필요하지 않은 데이터셋을 고르고, 실험을 진행하였습니다. 결과는 상대적 위치 정보가 절대적인 위치 정보를 이긴 결과를 볼 수 있습니다.

4.3 Relative Effective Context Length

ECL은 문맥 길이를 늘렸을 때 어느 정도 이상의 성능 향상이 있을 최대 길이를 측정하는 방법이었지만, 모델 간 공정한 비교가 어렵다는 한계가 있었습니다. 그리하여 저자는 RECL 이라는 새로운 방법을 제안하였습니다. RECL은 여러 모델을 비교할 수 있도록 동일한 기준을 공유하며, 긴 문맥의 이득을 단문맥 모델 대비 상대적인 개선으로 측정합니다. 또한, RECL은 'r' 파라미터를 도입하여 상위 r% 어려운 예시들에 대한 비교를 제한합니다

결과적으로 Transformer-XL은 평균 900단어 길이의 의존성을 모델링할 수 있으며, RECL은 순환 신경망과 기존의 Transformer 모델보다 각각 80%, 450% 더 긴 것으로 나타났습니다

5. Conclusions

Transformer XL 구조는 매우 높은 perplexity 결과를 보여주고, RNN와 트랜스포머 보다 더 높은 장기 의존성, 더 빠른 속도를 보여주었습니다.

한줄로 요약하면?

Transformer XL 구조는 이전의 정보를 활용하는 segment recurrent mechanism과 상대적 위치 인코딩을 통합한 것으로 위치 정보를 긴 문장에 특화된 Transformer 구조이다.

profile
그냥 끄적여 보는 논문..

0개의 댓글