현재까지 발표된 linear complexity를 가지는 self-attention에 대해서 정리를 해보려고 한다.
Linformer
Linear Transformer
Performer
1. Linformer
먼저 Linformer에 대해서 설명하면, Linformer 저자들은 self-attention matrix의 정보가 k개의 singular value로 reconstruct될 수 있다고 주장한다. 즉 low-rank로 근사될 수 있다고 실험을 통해 주장했고 SVD(Singular Value Decomposition)을 통해 attention matrix의 dimension을 낮추며 dot product 자체의 연산량을 줄일 수 있다고 한다. (실제 실험을 통해 MLM task에서 content matrix자체가 few largest singular value로 recover되었음.)
그러나 모든 sample에 대해서 SVD로 근사하는 것은 매우 inefficient하기 때문에 linear layer를 통해서 key와 value matrix를 low dimension k으로 projection했다.
Complexity:
O(lk)
2. Linear Transformer
Linear Transformer는 softmax 함수에 집중하여 연산량을 감소시켰다. 기존 softmax 함수를 포함한 self attention 식은 다음과 같다.
A(x)=A′=softmax(DQK⊤)V
이때 특정 i번째 sequence에서의 식은 다음과 같다.
A′=∑j=1Nsim(Qi,Kj)∑j=1Nsim(Qi,Kj)Vj
여기서 sim함수는 유사도를 계산하는 유사도 함수로 볼 수 있고 본 저자들은 분수 형태로, 즉 다른 토큰과 비교해서 j번째 토큰이 상대적으로 얼마나 유사한지 측정하기 위해서 삽입되는 분모가 필수가 아니라고 주장했다.
본 저자들은 따라서 유사도 함수 sim을 일종에 kernel function으로 생각했다.
sim(x,y)=sim(y,x):R2×F=exp(Dxy⊤)
유사도 함수가 두 벡터(query, key)를 input으로 scalar를 출력으로 하기 때문에 kernel function으로 볼 수 있고 x,y를 입력으로 고차원에서 mapping하는 과정을 내적하는 과정으로 보는 것이다. 여기서 하나의 제약 조건은 kernel function의 값이 양수가 되어야 한다는 것이다. 그렇다면 attention 연산의 식은 아래와 같이 변형해서 쓸 수 있다.
A′=∑j=1Nϕ(Qi,Kj)∑j=1Nϕ(Qi,Kj)Vj
그리고 결합법칙을 통해 아래와 같이 쓸 수 있다.
A′=ϕ(Qi)∑j=1Nϕ(Kj)⊤ϕ(Qi)∑j=1Nϕ(Kj)⊤Vj
Summation 항에서 query가 빠졌기 때문에 complexity자체가 linear하게 바뀐다. 먼저 한번만 ∑j=1Nϕ(Kj)Vj⊤ 와 ∑j=1Nϕ(Kj)를 계산하고 query와의 유사도를 구하는데 사용하면 되기 때문이다.
Complexity
O(lc)
c : feature map dimension
3. Performer
기존 attention 연산과정에서 softmax함수를 통과하고나면 기존 query와 key요소로 분해하는 것은 불가능하다. 그러나 attention matrix를 기존 query와 key에 대한 random nonlinear functions의 product로 분해하는 것은 가능하다. 즉 기존 Q,K,V∈RL×D이고 attention matrix가 A∈RL×L일 때 kernel function을 통해 Q′,K′∈RL×r로 근사할 수 있다.
Performer에서는 base로 Random Feature Attention을 사용했는데 이는 softmax 함수를 근사하기 위해서 random feature method를 통해 complexity를 감소시켰다. 또한 저자들은 향상된 방법을 통한 RFA를 통해 kernel approximation의 error를 감소시켰다. RFA의 main therom은 아래와 같다. ϕ:Rd→R2D가 비선형 transformation일 때,
또한 random feature map으로 trigonometric function을 사용하는 것은 unstable하기 때문에 positive random feature를 사용하는 것이 좋다고 하며 Orthogonal random features를 사용하여 variance를 줄일 수 있다고 저자들은 밝혔다.