Transformer는 NLP, CV, and audio processing에서 great successes를 보여주고 있다.
core components로, softmax attention은 long-range dependencies를 효과적으로 capture하지만,
sequence length에 따라 space and time complexity가 quadratic 증가하여 scale-up에 제약이 있다.
이를 해결하기 위해 kernel methods가 자주 사용되어 softmax operator를 approximating하지만,
이 과정에서 발생하는 approximation errors로 인해,
vanilla softmax attention에 비해 crucial performance drops이 발생한다.
본 논문에서는 COSFORMER라는 Linear Transformer를 제안한다.
이 model은 standard Transformer와 비교했을 때 comparable or better accuracy를 달성할 수 있으며,
Causal Attention과 Cross Attention 모두에서 효과적이다.
COSFORMER는 Softmax attention의 two key properties에 기반한다:
COSFORMER는 이러한 특성을 유지하면서도,
a linear operator와 cosine-based distance re-weighting mechanism을 활용한다.
(문제 & 기존 연구)
(기존 연구의 문제점)
(motivation)
(Ours Method)
이 section에서는 COSFORMER라는 우리의 linear transformer에 대한 technique details을 제공한다.
COSFORMER의 key insight는 non-decomposable non-linear softmax operation을,
decomposable non-linear re-weighting mechanism으로 대체하는 것이다.
Our model은 causal attention과 cross attention 모두에 적용이 가능하며,
input sequence length에 대해 a linear time and space complexity를 가지므로 long-range dependency를 효과적으로 modeling하는 strong capacity를 보인다.
linear attentions에 대한 key는 decomposable simlarity function 을 찾는 것이다.
여러 연구들이 있었지만.... (skip)
본 논문에서는 softmax를 대체하는 a new replacement of softax를 제안하며,
이는 다양한 task에서 softmax attention과 comparable or better performance를 달성함과 동시에,
linear space and time complexity를 갖는다는 장점이 있다.
vanilla transformer architecture에서는 일 때, softmax operation이 적용되어 attention matrix 에 대해 row-wise normalization이 수행된다. (Eq. 2)
즉, input sequence의 각 element가 다른 모든 elements와 맺는 relations을 normalize함으로써, contextual information의 weighted aggregation을 계산한다.
그러나 softmax attention이 실험적으로는 좋은 성능을 보임에도 불구하고, 그것의 핵심적이고 필수적인 특성이 무엇인지에 대해서는 the original transformer paper and follow-up works에서도 명확하게 규정되지 않았다.
이 연구에서, 우리는 softmax operation이 performance에 중요한 역할을 하는 two key properties를 식별했다:
위의 assumption을 validate하기 위해,
우리는 Table 1에 제시된 바와 같이 preliminary studies를 설계했다.
먼저, non-negativity의 중요성을 검증하기 위해 equation 3에서 the function 의 three instantiations (변형)을 비교한다:
다음으로, non linear re-weighting의 효과를 보여주기 위해,
우리는 softmax 연산 없이 만 사용하는 model과 softmax operations을 포함한 model을 비교한다.
Table 1에서 보이듯이, 가 와 보다 superior results를 보이는 것은
the benefit of retaining non-negative values를 보여준다.
우리의 conjecture는 similarity matrices에서 positive values만 유지함으로써, model은 negative correlation을 가지는 features를 무시하게 되며,
결과적으로 irrelevant contextual information을 agregating하는 것을 효과적으로 피할 수 있다. (?)
또한 와 softmax를 비교한 결과, softmax re-weighting을 사용하는 models이 더 빠르게 converge하고,
downstream task에 better generalize되는 것을 관찰했다.
이는 softmax normalization이 correlated pairs를 증폭시켜, useful patterns을 식별하는 데 유용하기 때문일 것이다.
softmax attention으로부터 도입되는 the non-linear re-weight mechanism은
the distribution of the attention weights에 집중할 수 있게 하고,
그렇기 때문에 training process가 stabilize하다.
우리는 또한 어떤 경우에는 far-away (멀리 떨어진) connections을 punish하여 locality를 강화한다는 것을 경험적으로 발견했다.
실제로 이러한 locality bias- 즉, 대부분의 contextual dependencies가 neighboring tokens으로부터 온다는 특성은 downstream NLP tasks에서 일반적으로 관찰된다.
이는 Figure 3 (1)에서도 확인할 수 있다.
위 assumption을 바탕으로, softmax의 second property를 만족시키기 위해 필요한 것은 attention matrix에 recency bias (최근 정보 bias)를 도입할 수 있는 a decomposable re-weighting mechanism이다.
여기서 우리는 우리의 목적에 부합하는 cos-based re-weighting mechanism을 제안한다:
구체적으로, Eq 6를 결합함으로써, the model with cosine re-weighting은 다음과 같이 정의한다:
re-weighting이 갖는 의미는 뭘까...?
Ptolemy's theorem에 의하여, 우리는 이 formulation을 다음과 같이 decompose한다:
- softmax-based similarity function이 갖는 특징
- attention matrix are non-negative
- non-linear re-weighting scheme
- ReLU-based linear similarity function은 softmax-based similarity function이 갖는 특징의 대부분을 활용할 수 있음.
- attention matrix are non-negative
- non-linear, but not yet re-weighting scheme...
- CosFormer는 ReLU-based linear simlarity function의 2.에서 re-weighting scheme이 없는 것을 보완하기 위해 cos-based re-weighting mechanism을 제안함.
- attention matrix are non-negative
- non-linear, -based re-weighting scheme
그래서 cos-based nonlinear re-weighting을 어떻게 했는데?
인접한 token에는 1에 가까운 weighting, 멀리 떨어져 있는 token에는 0에 가까운 weighting.