Attention Heads of Large Language Models: A Survey

ingeol·2024년 12월 31일
0

논문리뷰

목록 보기
63/63

Introduction

사람의 생각 process에 영감을 받은 4개의 stage framework를 소개한다: Knowledge Recalling, In-Context Identification, Latent Reasoning, and Expression Preparation. attention head 의 기능을 4가지로 구분하고자함. 추가적으로 special heads에 대해 modeling-Free 와 Modeling-Required method로 나눠 구분한다.

black-box 모델의 internal reasoning process 분석이 다양하게 진행되고 있고, 이때 attention head는 reasoning process에 중요한 역할을 한다.

해당 논문에서는 4개의 contribution에 집중한다.

  • Focus on the latest research: BERT와 같은 옛날 모델 보다 최근 모델인 GPT, LLaMA와 같은 모델에 집중한다.
  • An innovative four-stage framework for LLM reasoning: cognitive neuroscience로 부터 영감을 받아 인간의 사고과정을 4개로 분류하고 이를 LLM reasoning에 적용한다.
  • Detailed categorization of attention heads: 분류한 4개의 카테고리에 attention head의 역할을 분류한다.
  • Clear summarization of experimental methods: model dependency 관점을 바탕으로 attention head를 발견하기 위한 실험방법을 말해준다.

Background

Mathematical representation of LLMs

LLM decoder-only 모델의 attention 과정을 수식으로 나타내면 다음과 같습니다:

  1. 입력 임베딩:
    X=(x1,x2,...,xn)\mathbf{X} = (x_1, x_2, ..., x_n)
  2. Query, Key, Value 행렬 생성:
    Q=XWQ\mathbf{Q} = \mathbf{X}\mathbf{W}_Q
    K=XWK\mathbf{K} = \mathbf{X}\mathbf{W}_K
    V=XWV\mathbf{V} = \mathbf{X}\mathbf{W}_V
  3. Scaled Dot-Product Attention:
    Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d_k}}\right)\mathbf{V}
    여기서 $$d_k$$는 key 벡터의 차원입니다.
  4. Masked Self-Attention:
    MaskedAttention(Q,K,V)=softmax(QKTdk+M)V\text{MaskedAttention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d_k}} + \mathbf{M}\right)\mathbf{V}
  5. Multi-Head Attention:
    MultiHead(X)=Concat(head1,...,headh)WO\text{MultiHead}(\mathbf{X}) = \text{Concat}(\text{head}_1, ..., \text{head}h)\mathbf{W}O
    **where headi=Attention(XWQi,XWKi,XWVi)\text{where head}i = \text{Attention}(\mathbf{X}\mathbf{W}{Q_i}, \mathbf{X}\mathbf{W}{K_i}, \mathbf{X}\mathbf{W}{V_i})
  6. 최종 Attention 출력:
    Z=LayerNorm(X+MultiHead(X))\mathbf{Z} = \text{LayerNorm}(\mathbf{X} + \text{MultiHead}(\mathbf{X}))

Glossary of key terms

Circuits: model 의 subgraph로 모델이 특정 행동을 할 때 주로 활성화 되는 부분들을 나타내는 것.

Residual Stream: residual stream은 지속적인 정보 흐름을 가능하게 하는 요소. 낮은 레이어에서 정보를 작성하고, 높은 레이어에서 이를 계속 읽을 수 있게 해줌.

QK Matrix & OV Matrix: attention 수식을 약간만 다르게 보면 아래와 같은 관점으로 볼 수 있다.

이때, QK Matrix는 attention score를 만들 때 사용되는 weight로 특정 residual stream에서 정보를 읽는데 활용되는 요소이다.

또한 OV Matrix는 residual stream에 정보를 작성해 보내주는 역할을 하는 요소이다.

Activation Patching: 모델 최종 decision에 영향을 분석하는 방법. 특정 layer의 value를 빼거나 perturbation을 진행해 최종 결과에 대한 value의 영향을 분석함.

Ablation study: component를 빼보면서 output이 어떻게 영향을 받는지 분석하는 방법 (연산량 문제로 잘 사용하지 않는 방법)

Overview of special attention heads

최근 연구들은 scaling law를 따라 모델을 학습시켰을 때 특정 파라미터 이상에 emergent abilities 가 나오는 것을 확인했지만 이런 현상을 통해 왜 모델이 더 좋은 성과를 낼 수 있는지 완전히 이해하지 못했다.

이를 해결하기 위해 최근 연구들은 internal mechanisms of LLM 을 보기 시작했고, attention head의 기능들에 집중하고 있다.

본 논문에서 human congnitive paradigm에서 영감을 받은 4-stage framework로 attention head의 역할을 구분하고 설명하고자 한다.

Knowledge Recalling (KR): 문제 해결을 위해 관련된 이슈들을 생각해보는 과정

In-context Identification (ICI): text적인 정보뿐만 아니라 의미적 문맥적 정보 파싱

Latent Reasoning (LR): 결과 내기위해 정보를 통합하는 과정

Expression Preparation (EP): 자연어의 형태로 표현하는 과정

Knowledge Recalling (KR)

LLM 의 parametric knowledge는 학습과 fine-tuning 을 통해 저장되고 사람과 비슷하게 LLM 내부 저장된 지식을 상기시키는 attention heads가 존재한다. 이 attention head는 common sense or domain-specific expertise를 이후 reasoning 에 사용할 수 있게 한다.

KR 에 해당하는 head는 일반적으로 컨텍스트 내의 특정 내용을 기반으로 초기 추축이나 지식을 검색하고, 메모리 정보를 초기 데이터 또는 보충 정보로 잔여 스트림에 주입한다.

  • General tasks attention head는 학습과정에서 점진적으로 정보를 모으로 내부 지식을 가져오는 associative memories를 가지고 있다. 이 head는 superposed activation feature의 노이즈를 제거하며 본질적인 feature만들 보존하려한다. 또한 dimension이 커질수록 유용한 memories와 연결되고 관련정보를 가져오는 능력이 커진다. 이는 Memory Head라고 부르며 내부 파라미터에서 관련정보를 가져올 수 있는 특징을 가진 head이다. Memory head는 관련 속성들은 residual stream으로 가져오는 역할을 한다.
  • Specific task Multiple Choice Question Answering (MCQA)테스크의 경우 Constant Head 라는 모든 선택지에 동등한 attention score를 주는 head가 존재한다. 이와 동시에 Single Letter Head라는 head는 하나의 option에만 높은 attention score를 주게되며 답변을 선택하는데 영향을 준다고 말한다. 반면 Binary Decision Tasks (BDT) 테스크, yes-no 를 결정하거나 answer verification 하는 테스크에서 모델 (gemma, mistral, llama, qwen, gpt) 시리즈 들은 negative bias를 가지고 있어 우전석으로 부정적인 답변을 고르는 경향이 있다고한다. 이에 영향을 주는 head를 Negative Head라고 한다.

In-context Identification (ICI)

In-context 를 이해하기 위해서 head는 structural, syntactic, and semantic information을 QK matrix 확인하며, OV matrix를 통해 residual stream으로 정보를 전달해준다.

  • Overall Structural Information Identification Positional Head, Previous Head는 token sequence에서 positional relationship을 캡처하는 head로 역할을 한다. Duplicate Head는 content가 반복되는 내용을 캡처하며 토큰이 더 반복적으로 나타날 수록 이부분에 집중하는 특징을 가진다. 반대로 Rare Words Head는 거의 빈도가 적은 token에 포커스를 두는 헤드로 작동을한다. 또한 long text (Needle in a haystack) 과 같은 테스크를 할 때 Retrieval Head는 정확하는 특정 토큰의 위치를 잡아낼 수 있게 돕는다.
  • Syntactic Information Identification Syntactic Head는 nominal subjects, direct objects, clauses등의 문장 구조를 구분짓는 역할을 한다. Subword Merge Head는 토크나이저로 나눠진 하나의 단어 (subword)를 합쳐주는 역할을 한다.
    Mover Head는 argumnet parser의 역할을 한다. 이 헤드는 sentence에서 중요한 단어를 파악해 LLM decoding 할 때 전달해 주는 역할을 한다. 이는 요약 혹은 reasoning 에 도움이 된다고 한다.
  • Semantic Information Identification Context Head는 의미적 정보를 확인하고 이때 현재 테스트에서 관련이 있는 정보를 추출하는 역할을 한다. Semantic Induction Head는 문장 간 혹은 문장과 전체 파트의 의미적 관계를 캡처하는 역할을 한다.

Latent Reasoning (LR)

KR과 ICI 단계에서는 정보를 모으는데 초점은 둔다. LR 에서는 모아진 정보를 합성하고 logical reasoning 역할을 하는 문제해결에 가장 중요한 요소이다. QK 메트릭스를 통해 읽은 정보를 바탕으로 head에서 내재된 reasoning을 진행하고 이를 OV matrix 를 통해 residual stream으로 작성하는 역할을 한다.

  • In-context Learning Induction Heads는 “… [A][B] … [A]” 와 같은 형태의 패턴에 대해서 [B]라는 패턴을 예측하게 돕는 헤드이다. 이와 유사하게 In-context Head에서는 QK matrix 에서 각 in-context 마지막 토큰 포지션 (5-shot 이라고 하면 5개의 토큰 포지션)에 대해 답변생성하는 곳에서 유사도를 계산하는 역할을 한다.
  • Effective Reasoning Truthfulness Head와 Accuracy Head는 head의 활성화 방향으로 hook을 수정하면 QA task에서 reasoning 능력을 향상시킬 수 있다. 유사하게 Consistency Head는 LLM에 다양한 방식으로한 비슷한 질문에 대해 internal consistency를 활성화 시키는 역할을 한다.

Expression Preparation (EP)

EP stage에서는 reasoning 결과를 통합해 text로 표현해주는 단계이다.

Mixed Head는 ICI와 LR 단계에서 온 head의 정보를 선형결합을 통해 가져오는 역할을 한다. 통합된 결과는 residual stream에 작성되며 vocab logit 값에 맵핑된다.

몇몇 EP head들은 signal amplification 역할을한다. 대표적으로 Amplification Head와 Correct Head가 있는데 이 둘은 앞서 가져온 head의 정보들을 표현될 수 있게 하는 역할을 한다.

또한 reasoning 결과와 user’s instruction을 alignment시키는 head들이 있다. Coherence Head는 다국어 테스크에서 target language 를 잘못 말하게 되는 경우가 있는데 해당 head를 증폭시키면 user가 원하는 target language 로 답변이 나오게 된다.

마지막으로 Faithfulness Head는 model internal state의 reasoning을 반영시켜 CoT로 표현시켜주는 역할을한다. 이는 CoT 결과를 더 robust 하고 consistent 하게 만든 ㄴ역할을 한다.

Unveiling the discovery of attention heads

위에서 언급한 head들은 검출하는 방법에 대한 소개. 2가지 방식으로 나눌 수 있으며 각각 Modeling-Free 방식과 Modeling-Required 방식이 있다.

Modeling-Free

modeling-free 방식은 latent state 를 변형하거나 매핑시켜 logit 값을 변경시키는 과정이다. 주로 Activation Patching혹은 Ablation Study 가 사용된다.

Modification-Based Method

Directional Addition: specific latent state 보유하고 있는 기존 information 에서 같은 head 위치에 state를 더해주는 방법. 대표적인 예시로 positive, negative sentiments에 대한 representation을 추출하고 이를 각각에 감정에 더해줬을 때 변화하는 것을 관찰하는 방식이 있음. 더할 때 감정에 해당하는 head가 어느 부분인지 찾아내는데 도움이 된다.

Directional Subtraction: 위의 방식과 반대로 특정 state 정보를 제거해서 모델이 해당 행동을 안하게 하는 방식이 존재함.

Replacement-Based Methods

Modification-Based methods와 반대로 specific latent state를 다른 값으로 다 바꿔버리는 방법이다. 이는 기존 head에서 나온 original prompt 의 값에 corrupted prompt 값으로 바꿔 head의 특정 테스크에서의 영향력을 확인하는 방식이다.

Modeling-Required

해당 방법은 외부적인 모델링을 통해 특정 head의 기능을 확인하는 방법이다.

Training-Required methods

대표적인 방식으로 Probing 방식이 있다. 해당 방식은 activation value를 여러 해드로 부터 추출하고 각 분류된 해드에대해 classifier를 학습시키는 방식이다. 이는 activation pattern과 head의 관계를 볼 수 있게 한다.

다음으로 dataset으로 추가학습시킨 모델과 original model과 비교하는 방식이다. 이때 특정 헤드가 동일한 역할을 할 때 활성화되는 여부를 본다.

Training-Free methods

특정 현상에 대해 score는 매길 수 있으면 학습없이 내부 속성과 모델 행동간의 관계를 파악할 수 있다. 대표적으로 Retrieval Head의 Retrieval score 에서는 Needle in a Haystack 테스크에서 정보를 뽑아오는데 어떤 head가 영향을 주는지 파악하기 위해 사용했다.

0개의 댓글