XLNet

AI CAU_DSLAB·2022년 2월 23일
0

XLNet

  • XLNet
    • XLNet은 2019년 당시 대부분의 NLP Task에서 SOTA를 달성했던 BERT를 큰 차이로 밀어낸 모델
    • 핵심 Contribution
      • GPT로 대표되는 Auto-Regression모델과, BERT로 대표되는 Auto-Encoder 모델의 장점을 합한 Generalized AutoRegressive Pre-training Model
      • Permutation language model objective와, Two-stream attention Mechanism 제안
  • Introduction
    • Pre-training을 통해 얻어진 Representation을 직접적으로 활용(Word2Vec, ELMo)하거나 Pre-trained Model을 DownStream Task에 대해 Fine-tuning하는 방법(BERT,GPT)등이 성과를 보여주는 중
    • NLP DownStream Task를 위한 Pre-Training 대표적인 Objective
      • Auto Regressive(AR) [ELMo, GPT, RNNLM]
        • 일반적인 Language Model의 학습 방법으로 이전 Token을 예측하는 문제를 해결
        • LM의 Objective
          • input sequence :x=(x1,x2,...,xT)input \space sequence \space :x=(x_1,x_2,...,x_T)
          • forward likelihood :p(x)=t=1Tp(xtx<t)forward \space likelihood \space: p(x)= \prod^T_{t=1}p(x_t|x_{<t})
          • backward likelihood:p(x)=t=T1p(xtx>t)backward \space likelihood : p(x)=\prod^1_{t=T}p(x_t|x_{>t})
          • training objective(forward):maxlogpθ(x)=maxΣt=1Tlogp(xtx<t)training\space objective(forward): max logp_\theta(x)=max \Sigma^T_{t=1}logp(x_t|x_{<t})
          • maxΣt=1Tlogp(xtx<t)=maxΣt=1tlogexp(htheta(x1:t1)Te(xt))Σxexp(hθ(x1:t1)Te(x))max \Sigma^T_{t=1}logp(x_t|x_{<t})=max\Sigma^t_{t=1}log{exp(h_{theta}(x_{1:t-1})^Te(x_t))\over \Sigma_{x'}exp(h_\theta (x_{1:t-1})^Te(x'))}
            • hθ(x1:t1)h_\theta(x_{1:t-1})은 신경망에 의해 학습된 Context Representation
            • e(xt)e(x_t)는 x의 임베딩
        • Likelihood & Objective
          • Input Sequence의 Likelihood는 forward / backward 방향의 Conditional Probability들의 곱으로 나타냄
          • 모델은 이러한 Conditional Distribution을 Objective로 학습
        • AR은 방향성이 정해져야 하므로, 한 쪽 정보만을 이용할 수 있음
          • 양방향 문맥을 활용해 문장에 대해 깊이 이해하기 어려움
          • ELMo의 경우 양방향을 이용하지만, 각 방향에 대해 독립적으로 학습된 모델을 이용하므로 얕은 이해만 가능
      • Auto Encoding(AE)
        • AR과 달리 AE를 기반으로한 사전학습은 분포를 추정하지 않고, 손상된 데이터를 원래 데이터로 재건하는 것을 목표로 함
        • AE는 주어진 Input을 그대로 예측하는 문제를 품
          • 주로 차원 축소 등을 목적으로 이용
        • Denosing Auto Encoder는 Noise가 섞인 input을 원래 input으로 예측하는 문제를 품
          • BERT의 방식도 일종의 DAE로 볼 수 있음
            • Sequence Token을 일정 확률로 [MASK]로 변환하여 원래 토큰으로 복원하는 방식
        • LM의 Objective
          • input sequence:xˉ=(x1,x2,...,xT)input \space sequence:\bar x = (x_1,x_2,...,x_T)
          • corrupted input:x^=(x1,[MASK],...,xT)corrupted \space input: \hat x =(x_1,[MASK],...,x_T)
          • likelihood:p(xˉx^)t=1tp(xtx^)likelihood : p(\bar x| \hat x) \approx \prod^t_{t=1}p(x_t|\hat x)
          • training objective:maxlogp(xˉx^)=maxΣt=1Tmtlogp(xtx^)training \space objective : max logp(\bar x| \hat x)= max \Sigma^T_{t=1}m_tlogp(x_t|\hat x)
          • maxΣt=1Tmtlogp(xtx^)=Σt=1Tmtlogexp(Hθ(x^)tTe(xt))Σxexp(Hθ(x^)tTe(x))max \Sigma^T_{t=1}m_tlogp(x_t|\hat x)=\Sigma^T_{t=1}m_tlog{exp(H_\theta(\hat x)_t^Te(x_t))\over \Sigma_{x'}exp(H_\theta(\hat x)_t^Te(x'))}
            • x^\hat x = Corrupted Version
            • xˉ\bar x = Masked Token
        • 일반적인 DAE의 Likelihood p(xˉx^)p(\bar x | \hat x)와 이를 Maximize하는 Objective를 이용
        • But. 계산 과정에서 두 가지 차이점이 있음
          • independent assumption : 주어진 input sequence에 대해 각 [MASK] token의 정답 token이 등장할 확률은 독립이 아니지만, 독립으로 가정
            • 독립이므로, 각 확률의 곱으로 나타낼 수 있음
          • xtx_t가 [MASK] token일 경우, mt=1m_t=1, 나머지 경우에는 mt=0m_t =0
            • [MASK] token에 대해서만 prediction 진행
            • mtm_t를 둠으로써, [MASK] 토큰만 예측하는 Objective는 DAE의 Objective(input + noise에 대해 input을 복원, 즉 노이즈의 위치와 관계 없이 전체를 복원)와 다르지만 노이즈를 원래의 Input으로 복원하는 개념상 유사
        • AR과 달리 AE는 특정 [MASK] Token을 맞추기 위해 양방향 정보를 이용할 수 있음
        • But. Independent assumption으로 모든 [MASK] token이 독립적으로 예측됨으로써, 이들 사이의 dependency를 학습할 수 없다는 단점이 있음
          • 또한, 노이즈(마스크 토큰) 자체는 실제 fine-tuning 과정에 등장하지 않으므로, pre-training과 fine-tuning간의 불일치 발생
  • Proposed Method : XLNet
    • 위 두 가지의 장점을 살리고 단점을 극복하기 위해 세가지의 새로운 방식을 제안
      • Objective
      • for Objective Target-Aware Representation
      • Two-stream self-Attention 구조
    • Objective : Permutation Language Modeling
      • 길이 T의 Sequence X=[x1,x2,...,xT]X=[x_1,x_2,...,x_T]가 주어졌을 때, 시퀀스를 나열할 수 있는 모든 순서의 집합 (ZT)(Z_T) 순열은 [1,2,3,...,T],[2,3,4,...,T],...,[T,T1,...,1][1,2,3,...,T],[2,3,4,...,T],...,[T,T-1,...,1]등 총 T!T!개 만들 수 있음
      • 새로운 objective는 다음 식과 같이 위 집합(ZT)(Z_T)에 속해있는 모든 순서들을 고려하여 AR 방식으로 모델링을 진행
        • input sequence:x=(x1,x2,...,xT)input \space sequence : x=(x_1,x_2,...,x_T)
        • likelihood;EzZT[t=1Tpθ(xztxz<t)]likelihood ;E_{z\sim Z_T}[\prod^T_{t=1}p_\theta(x_{z_t}|x_{z<t})]
        • training objective:maxθ EzZT[Σt=1Tlog pθ(xztxz<t)]training \space objective :max_\theta \space E_{z \sim Z_T}[\Sigma^T_{t=1}log \space p_\theta(x_{z_t}|x_{z<t})]
      • 각 순서에 대한 log likelihood 기대값을 최대화
        • 기존의 AR모델링은 해당 Objective의 순열 중 한 가지 경우 원래의 순서([1,2,3,..,T][1,2,3,..,T])만을 고려
      • 즉, 본 방법은 input Sequence index(순서)의 모든 Permutation을 고려한 AR방식 이용
        • ZTZ_T를 기준으로, 여러 순열이 있을테지만, 하나의 텍스트를 모든 순열의 순서에 대해서 고려하는 것은 불가능하기에, 특정 순열의 순서(z)(z)를 샘플링 하고 해당 순서에 대한 pθ(x)p_\theta(x)를 분해
          • 많은 데이터를 학습하면 Parameter θ\theta는 학습하는 동안 모든 순서에 대해 공유되므로 모든 순서를 고려한다고 생각할 수 있음
          • 이 방식으로 모델은 자연스럽게 어떤 근사 없이 양방향 컨텍스트를 볼수 있게 됨
        • 여기서 주의해야 할 점은, Model은 Input(시퀀스)에 대해 순서를 섞는 것이 아니고, p(x)p(x)에 대한 조건부 확률들의 곱으로 분리할 때만 순서를 섞음
          • 이로인해, 모델이 기존 시퀀스 토큰들의 절대적인 위치를 알 수 있게 학습
    • Architecture : Two-Stream Self-Attention for Target-Aware Representation
      • Target-Aware Representation
        • 위에서 제안한 Objective를 기존 Transformer에 적용하면 문제가 발생
          • 일반적으로 파라미터 θ\theta를 갖는 모델로 다음 토큰 분포 pθ(Xztxz<t)p_\theta(X_{z_t}|x_{z<t})를 예측하기 위해 모델의 최종 hidden state(hθ(xz<t)h_\theta(x_{z<t}))와 Softmax 사용
            • pθ(Xzt=xxz<x)=exp(e(x)Thθ(xz<t))Σxexp(e(x)Thθ(xz<t))p_\theta(X_{z_t}=x|x_{z<x})={exp(e(x)^Th_\theta(x_{z<t}))\over \Sigma_{x'}exp(e(x')^Th_{\theta}(x_{z<t}))}
          • 이 때, transformer의 hθ(xz<t)h_\theta(x_{z<t})는 예측되는 토큰의 위치에 관계없이 일정한 값을 가짐
          • 기존 AR모델링에서는 Context가 고정되면 예측할 토큰의 위치가 다음 시점(t)로 고정되어 문제가 발생하지 않지만, 위 Objective의 경우 순열된 순서를 고려하기에 주어진 Context(xz<tx_{z<t})가 고정되더라도 예측할 토큰의 위치가 고정되지 않기 때문에 예측 위치에 대한 정보가 추가적으로 필요함
          • 이에 대한 해결책으로 모델의 Input에 토큰의 위치 정보 (zt)(z_t)를 추가적으로 제공
            • hθ(xz<t)>gθ(xz<t,zt)h_\theta(x_{z<t}) -> g_\theta(x_{z<t},z_t)
            • pθ(Xzt=xxz<t)=exp(e(x)Tgθ(xz<t,zt))Σxexp(e(x)Tgθ(xz<t,zt))p_\theta(X_{z_t}=x|x_{z<t})={exp(e(x)^Tg_\theta(x_{z<t},z_t))\over \Sigma_{x'}exp(e(x')^Tg_{\theta}(x_{z<t},z_t))}
      • Two-Stream Self-Attention
        • hθ(xz<t)>gθ(xz<t,zt)h_\theta(x_{z<t}) -> g_\theta(x_{z<t},z_t)에 대한 변경
          • time step ztz_t는 주변 Context(xz<tx_{z<t})와 Attention으로 정보를 축정해 나가는 방식 제안
          • 두 가지 Constraint
            • 특정 시점 t에서 target position ztz_t의 token xztx_{zt}를 예측하기 위해, hidden representation g(xz<t,zt)g(x_{z<t},z_t)는 t 시점 이전의 context 정보 xz<tx_{z<t}와 target position 정보 ztz_t만을 이용 해야. 함
              • 단어 자체의 정보를 알게 된다면 Cheating
            • 특정 시점 t 이후 j(>t)j(>t)에 해당하는 xzjx_{zj}를 예측하기 위해 hidden representation g(xz<t,zt)g(x_{z<t},z_t)가 t시점의 content인 xztx_{z_t}를 인코딩 해야 함
          • 위 조건들은 특정 시점에서 하나의 hidden state를 인코딩하는 기존 Transformer구조에서는 작동하지 않음
        • Architecture
          • 모델의 기본 구조는 위와 같음
          • 두가지 hidden representation 제안
            • Query Representation
              • gzt(m)<Attention(Q=gztm1,KV=hz<tm1;θ)g^{(m)}_{z_t}<- Attention(Q=g^{m-1}_{z_t},KV=h^{m-1}_{z<t};\theta)
              • 첫 번째 Constraint를 해결하기 위해 제안
              • 현재 시점을 제외한 이전 시점 token들의 Content와 현재 시점의 위치 정보를 이용하여 계산되는 Reprentation
              • 마지막 층의 Query Representation을 이용하여 현재 position의 토큰을 예측하는 pre-training objective 계산
              • 첫 층의 Query는 훈련 가능한 random value로 초기화하고 위 식과같은 방식으로 각 layer의 State 계산
            • Content Representation
              • hzt(m)<Attention(Q=hztm1,KV=hztm1;θ)h^{(m)}_{z_t}<- Attention(Q=h^{m-1}_{z_t},KV=h^{m-1}_{z \le t};\theta)
              • 현재 시점 및 이전 시점 Token들의 Content를 이용하여 계산되는 Representation
              • 바닐라 transformer의 Hidden State와 동일한 역할
              • 첫 층의 Content Stream은 해당 위치 Token의 Word embedding으로 초기화
              • 각 층의 State는 위 식과 같이 계산
              • t index 이후 Token의 State들은 마스킹하여 계산
        • Partial Prediction
          • 위 방식은 사실 매우 느리고 Optimize가 어려움
            • 이를 해결하기 위해,특정 순서에서 마지막 몇 개의 예측만 이용하는 방식을 사용
          • 위 식을 최대화 하는 방식
      • Incorporating Ideas from Transformer-XL
        • XLNet은 긴 문장의 처리를 위해 Transformer-XL에서 사용 된 두 가지 테크닉을 차용
          • Relative Positional Encoding
            • 바닐라 Transformer는 Self-attention을 기반으로하고, CNN이나 RNN과 달리 단어들의 상대적, 절대적 위치 정보를 직접적으로 모델링하지 않음
            • 대신 Input에 단어의 절대적인 위치에 대한 Representation (Absolute positional encoding)을 추가하는 방식으로 순서에 대한 모델링을 함
              • But. 위 방식은 하나의 Segment 내에서는 위치에 대한 의미를 표현할 수 있지만, Transformer-XL과 같이 여러 Segment에 대해 Recurrent 모델링을 하는 경우 문제가 있음
              • 아래 식은 Segment-level의 Recurrence의 식
                • hτ+1=f(hτ,Esτ+1+U1:L)h_{\tau+1}=f(h_\tau,E_{s_\tau+1}+U_{1:L})
                  • τ+1\tau+1번 째 segment의 문장 sτ+1s_{\tau+1}에 대한 hidden state를 구하는 식
                • hτ=f(hτ1,Esτ+U1:L)h_\tau=f(h_{\tau-1},E_{s_\tau}+U_{1:L})
                  • τ\tau번째 segment의 문장 sτs_\tau에 대한 hidden state를 구하는 식
              • 위 식에서 τ\tau는 segment의 순서를 의미
              • EsτE_{s_\tau}는 input 문장 sτs_\tau의 word embedding
              • ff는 transformation function
            • 위 식에서 사용하는 input을 주목하면 word embedding EE의 경우 segment순서에 맞춰서 알맞게 들어 갔으나, U1:LU_{1:L}의 경우 τ\tau번 째 Segment에 속한 단어들의 position이 τ+1\tau+1번째 segment의 단어들의 position보다 앞에 있지만, 둘 다 같은 위치를 표현하는 U1:LU_{1:L}을 사용
            • 즉, 위 두 식에서 τ\tau번째 Segment의 첫 번째 단어와, τ+1\tau+1번째 segment의 첫 번째 단어를 위치상 같다고 인식
            • 이러한 문제를 해결하기 위해서 Transformer-XL과, XLNet은 input-level이 아닌 self-attention mechanism에서 relative positional encoding이라는 단어 간의 상대적 위치 정보를 모델링하는 기법을 제안
              • Attention score in standard Transformer
              • Attention score with Relative Positional Encoding
              • Term (b)와 (d)에서 기존 absolute embedding UjU_j를 relative positional embedding RijR_{i-j}로 대체
                • RR은 학습가능하는 파라미터가 아닌 sinusoid encoding matrix
              • Term (c)와 (d)에서 UiTWqTU^T_iW^T_q를 각각 uT,vTu^T,v^T로 대체
                • Query vector가 모든 query position에 대해 같기 때문에 다른 단어들에 대한 attention bias가 query position에 상관없이 동일하게 유지되어야함
              • WkW_kWk,E,Wk,RW_{k,E},W_{k,R}로 분리
                • Content 기반의 key vector와 location 기반의 key vector를 만들어 내기 위함
              • Attention score with Relative Positional Encoding의 Term (a)는 content를 처리, Term (b)는 content에 의존한 positional bias를 잡아냄 Term (c) global content bias Term (d)는 global positional bias 인코딩
          • Segment Recurrence Mechanism
            • XLNet은 긴 문장에 대해서 여러 segment로 분리하고 이에 대해 recurrent하게 모델링을 함
            • Transformer-XL에서 제안된 segment-level recurrence를 XLNet에 적용하기 위해 2가지 포인트에 주목
              • 어떻게 permutation setting에 recurrence mechanism을 적용할 것인지
              • 이전 segment로 부터 얻어진 hidden state를 재사용할 수 있게 하는 것
            • x~=s1:T,x=sT+1:2T\tilde x=s_{1:T}, x=s_{T+1:2T} 일 때, z~=[1,2,...,T],z=[T+1,...,2T]\tilde z=[1,2,...,T], z=[T+1,...,2T]의 순열이라면, z~\tilde z를 기반으로 첫 번째 segment에 대한 처리를 하고, 각 layer m으로 부터 얻어진 content representation h~(m)\tilde h^{(m)}을 caching
            • 이후, 두 번째 Segment에 대한 계산은 아래의 수식과 같이 나타냄
              • hzt(m)<Attention(Q=hzt(m1),KV=[h~(m1),hzt(m1)];θ)h^{(m)}_{z_t}<-Attention(Q=h_{z_t}^{(m-1)},KV=[\tilde h^{(m-1)},h^{(m-1)}_{z \le t}];\theta)
            • [.,.][.,.]은 concatenation. h~(m)\tilde h^{(m)}를 이전 segment 처리에서 계산하고 나면, z~\tilde z와 독립적으로 현재 segment (z)에 대한 attention update가 이뤄짐
              • 이를 통해 과거 segment에 대한 factorization order를 고려하지 않고, memory의 caching과 reusing 가능
            • Query stream gzt(m)g^{(m)}_{z_t}에 대해서도 같은 방식으로 계산 가능
      • Modeling Multiple Segments
        • XLNet을 어떻게 multiple segment에 대해 AR하게 모델링하고 학습시키는지
        • Pre-traning
          • input은 BERT와 유사하게 [A, SEP, B, SEP, CLS]의 형태로 주어짐
            • SEP CLS는 스페셜 token
            • Segment A, Segment B에 들어가게 될 두 개의 문장(segment)를 랜덤으로 샘플링
            • 이후, 두 Segment를 하나의 문장으로 Concat하여 Permutation을 수행
        • Relative Segment Encoding
          • BERT는 Absolute Segment Embedding을 사용하지만, XLNet은 Relative Position Encoding과 비슷한 원리로 Relative Segment Encoding 적용
            • 전체 Sequence에서 주어진 position i,ji,j가 같은 Segment라면, sij=s+s_{ij}=s_+ 아니면 sij=ss_{ij}=s_-를 사용 s+,ss_+, s_-는 Attention head에 존재하는 학습가능한 파라미터임
          • 위 방식을 통한 이점
            • Relative encoding의 inductive bias가 generalization을 향상
            • 둘 이상의 segment를 갖는 fine-tuning task에 대한 가능성을 열어줌
      • Discussion
        • BERT와의 비교
          • 예를 들어, [New, York, is, a, city]라는 문장(Sequence of words)가 주어졌을 때, BERT와 XLNet 모두 예측할 토큰으로 [New, York] 2개를 선택하여 logp(NewYorkis a city)logp(New York|is \space a \space city)를 maximize 해야 하는 상황
          • 이때 BERT와 XLNet의 objective는 아래와 같음
          • JBERT=logp(Newis a city)+logp(Yorkis a city)J_{BERT}=logp(New|is \space a \space city)+logp(York|is \space a \space city)
          • JXLNet=logp(Newis a city)+logp(YorkNew,is a city)J_{XLNet}=logp(New|is\space a \space city)+ logp(York|New, is \space a \space city)
          • 즉, XLNet에서는 각 word간의 Sequence의 순서에 관계 없이 Dependency를 찾을 수 있음
        • LM과의 비교
          • GPT와 같은 AR LM류 모델들은 과거에 대한 Dependency만을 고려할 수 있음
          • 이는 QA와 같이 span extraction을 하는 문제에서 취약할 수 있음
          • 하지만 XLNet에서는 가능
profile
중앙대학교 Data Science Lab입니다.

0개의 댓글