이로부터 우리는 query selection을 잘못한다면 model weight이 엄청 커질 수 있음을 알 수 있다. (upper bound가 없어지므로, 비슷한 쿼리에 다른 specificity를 할당할 수 있음) 여기서 bottleneck이 발생하게 되는데
weight이 커질 수록 최적화 속도 감소 (vanishing softmax derivative)
overfit, overly sensitive final model
실험 결과 norm을 찍어봐도 selective attention의 attention weight이 sparse하면서도 더 작은 attention임이 확인 가능하다.
같은 attention head도 더 빨리 특정 쿼리의 distinct/sparse attention을 파악하는 것이 가능해짐을 확인하고자 추가적인 실험을 진행한다.
확인을 위해 specificity를 기준으로 class를 구분하여 undirected graph를 만들어 둔다. 여기서 stochastic matrix 를 얻을 수 있을 것이다. 이러한 를 추정하기 위해 token prediction 실험을 진행한다.
결과적으로 cross entropy distance를 체크하여 salmonella가 왜 bacteria보다 낮은 temperature를 가지게 되는지를 설명할 수 있다.
추가적으로 이렇게 approximate한 결과와 간의 차를 err_map을 통해 확인한다. 확인해보면 SSA의 결과가 더 낮은 값을 가지며, 자연히 SSA가 fewer neighbors에 대해서 낮은 temperature를 assign한다는 직관과 일치한다.
좀 더 정리해서 Proposition으로 이어가보자.
attention dilution을 해결하기 위해 query position이 필요하다. 그러나 softmax score은 upper bound가 존재한다.
Feature imbalance setting에서 optimal temperature scaling이 없으면 풀 수 없는 task에 대해 생각해보자.
feature imbalance: 중요하지 않지만 자주 나오는 토큰 (ex. is) 가 중요하지만 덜 자주 나온 토큰에 비해 집중받는 경우
Imbalanced token setup
여기서 proposition 2가 등장한다.
proof for proposition 2
따라서 n이 증가할수록 불필요한 토큰을 가지고 있게 되어 relation 의 역인 k_n 의 증가로 이어진다. 이 power law를 따른다고 하면 temperature sacling rule은 아래와 같다.
따라서, position-aware scaling rule은 이러한 특성을 반영하여 logarithmic 하게 설계되었다.
Value embedding은 단순히 선형변환이므로, 어더한 토큰의 기여도는 attention score에 기반한 weighted sum이다. value temeprature scaling은 non-linear scalar weighting function으로 작용하며, temperature을 조정하여 각 토큰의 영향력을 조절할 수 있도록 한다.
이러한 장점을 Denoising task에서 확인할 수 있다.
Task setup
따라서, 이를 해결하기 위해 여러가지 nonlinear 방법을 활용했다.
Standard benchmarks
Passkey retrieval
Ablations
token-aware, position-aware
weight-sharing, feature-based
different-function
fixed attention row, let
이 때 sparsity와 temperature scaling 간의 관계는 분명하다. 예를 들어, top entry의 temperature가 감소한다면, entropy는 증가할 것이다. power-law assumption을 기반으로, 어떠한 attention이 2개의 값을 fraction of larger attention score로 파악한다고 해보자.
- c+: score attained by salient token
- gamma : score advantage of salient tokens over rest of tokens
- pow : fraction of salient tokens
Lemma
proof of Lemma 1