[논문리뷰] Titans : Learning to Memorize at Test Time

서정훈·2025년 9월 23일

LLM

목록 보기
1/4
post-thumbnail

느낀점

transformer 기반 모델이 긴 시퀀스를 input으로 가질 때의 문제점을 해결하기 위해 새로운 모델 제안
기존의 transformer를 뛰어넘는 성능을 지속적으로 보여온다면, 대 transformer의 시대의 막을 내릴수도 있을 것 같다.
점점 더 인간의 뇌 구조에 유사해지는 모델들이 등장 및 발전해오고 있음을 시사

Abstract

RNN : 데이터를 고정 크기의 메모리에 압축
Transforemr : 전체 context window에 주목하여 모든 토큰 간의 직접적인 의존성 포착
\rarr 계산 비용이 quadratic으로 증가하는 문제가 있어, 결국 모델이 다루는 context를 제한하여 사용

Titans : 기존 attention을 활용한 단기 메모리 모듈(short-term memory)로 작동하고, 신경 기반 기억으로 장기 기억(long-term memory)로 작동
\rarr 메모리를 병렬적으로 학습하며 빠른 추론 속도 보장

  • Titans라는 새로운 아키텍쳐 계열으로, 메모리를 효과적으로 통합하는 3가지 변형 형태까지 제안
  • Titans는 기존 방법들보다 더 높은 정확도로 200만 이상의 문맥 창까지 확장

1. Introduction

Transformer는 현재 context window 내의 토큰 간 직접적인 의존성만 고려한다. 이는 문맥 길이가 길어지면서 메모리 복잡도가 quadratic이기 때문에, 문맥 길이가 긴 복잡한 과제에서는 transformer의 적용이 어렵다.

이러한 문제를 해결하기 위해서 다양한 Linear Transformer 연구가 이루어지고 있다.

Linear Transformer에서는 Attention에서 사용한 Softmax 연산을 Kernal Function으로 대체하여, 메모리 사용량을 크게 줄여 연산 효율성을 높이고, 긴 context를 비교적 잘 처리
But, 긴 문맥 정보를 작은 벡터 또는 행렬 상태로 압축하는 것이기 때문에 성능 저하는 불가피

즉 특정한 기능을 담당하는 개별 모듈들을 가진 채, 이를 연결하는 시스템을 구성하는 것이 필요하다고 말함

Memory Perspective

신경심리학에서 흔히 사용되는 기억과 학습의 정의

  • 기억 : 입력에 의해 유도된 신경 업데이트
  • 학습 : 주어진 목적에 따라 효율적이고 유용한 기억을 획득하는 과정

RNN의 경우

  • 벡터 형태의 메모리 사용
  • 새로운 입력 xtx_t에 대해 이전 메모리 Mt1M_{t-1}을 사용하여 MtM_t 업데이트
    f(Mt1,xt)f(M_{t-1},x_t)
  • 입력에 대응하는 Memory를 검색하는 과정 수행
    g(Mt,xt)g(M_t,x_t)

Transformer의 경우

  • 행렬 형태의 메모리 사용
  • key-value 쌍을 메모리에 압축 없이 계속 추가하며 메모리 업데이트
  • query 벡터와 key 벡터 간의 유사도를 계산하여, output vector를 생성하기 위해 value 벡터에 가중치 적용

효과적인 메모리 구조를 설계하기 위한 다섯 가지 질문
1. 좋은 메모리 구조란 무엇인가?
(What constitutes a good structure for the memory?)
2. 적절한 메모리 업데이트 방식은 무엇인가?
(What is proper memory update process?)
3. 효율적인 메모리 검색
(What is a good memory retrieval process?)
4. 서로 연결된 메모리 모듈들을 효과적으로 결합한 아키텍처를 어떻게 설계할 것인가?
(How to design and efficient architecture that incorporates different interconnected memory modules?)
5. 과거의 정보를 효과적으로 저장하고 기억하기 위해 더 깊은 메모리 모듈이 필요할까?
(Is a deep memory module needed to effectively store/remember long past?)

Contributions and Roadmap

테스트 시간에 암기를 학습할 수 있는 장기 메모리 모듈 설계

Neural memory

  • in-context 모델처럼, test를 할 때, 데이터를 파라미터에 저장하고 기억하는 방법을 학습
  • 인간의 장기 기억 시스템에서 영감, 사람이 예측하기 어려운 정보(Surprising Event)를 더 잘 기억
  • Associative memory loss를 사용하여 입력에 대한 Surprise 정도를 판단하고, 이는 입력에 대한 신경망의 gradient를 통해 측정
  • Decay Mechnism을 도입하여, memory 크기와 surprise 정도를 고려한 동적 업데이트 수행. RNN 모델에서 사용되는 forgetting mechnism을 일반화한 형태

Titans Architectures

3가지 핵심 모듈로 구성

  1. Core

    • 단기 기억 담당 \rarr 데이터 흐름 주도
    • 제한된 attention 사용
  2. Long-term Memory

    • 신경 기반 장기 메모리 \rarr 과거 정보 저장하고 기억하는 역할
  3. Persistent Memory

    • 학습 가능한 데이터 독립적 파라미터들로 구성 \rarr 작업에 대한 지식을 담고 있음

메모리 처리를 위한 3가지 변형 모델 제안
1) context로써의 메모리
2) layer로써의 메모리
3) gated branch로써의 메모리

2. Preliminaries

M\mathcal{M} : 신경망 메모리 모듈
Q,K,VQ,K,V : 어텐션 메커니즘의 쿼리,키, 값
MM : 어텐션 마스크

시퀀스 분할 시, S(i)S^{(i)} 는 i 번째 세그먼트
Sj(i)S_j^{(i)} : i번째 세그먼트의 j 번째 토큰
tt : 시간 인덱스

신경망 N\mathcal {N}과 샘플 xx에 대해,
N(x)\mathcal{N(x)} : weight 업데이트를 포함한 정방향 패스
N(x)\mathcal{N^*(x)} : weight 없이 단순 추론만 수행하는 경우
N(k)\mathcal{N^{(k)}} : 신경망의 k번째 계층

2.1 Backgrounds

Attention

Query, Key, Value를 정의하고, softmax 기반의 weighted sum을 통해 계산.
기존 attention과 동일

Efficient Attentions

  • 긴 시퀀스에서 softmax 어텐션의 메모리 소비와 throughput 문제를 해결하기 위한 많은 연구 진행
    ex) I/O 최적화, 희소화, softmax 근사화, 커널 기반 선형 어텐션
  • Linear Attention(선형 어텐션) : softmax 대신 커널함수 ϕ(x,y)=ϕ(x)ϕ(y)\phi(x,y)=\phi(x)\phi(y) 사용
  • 커널 함수를 항등으로 설정 (ϕ(x)=x\phi(x)=x)
    Mt=Mt1+KtTVtM_t = M_{t-1}+K^T_t V_t
    yt=QtMty_t = Q_t\cdot M_t

Moddern Linear Models and Their Memory Perspective

  • RNN의 은닉상태는 학습된 정보가 저장되는 메모리 유닛으로 볼 수 있음
    \rarr 정보 저장 및 검색 효율적이지만, 메모리 오염 문제

    solution
    1. 망각 게이트 추가
    GLA, LRU, Griffin, xLSTM, Mamba2 등 동적 게이팅 구조 도입하여 메모리 조절
    2. 쓰기 연산 개선
    델타 학습 규칙(이전 값 제거 후 새 값 반영), Gated DeltaNet 등 수행

Memory Modules

  • 기존 방법 : Fast Weight Programmers, Hebbian Learning, Delta Rule
  • 이 방법들은 Surprise 기반이 아니고, 망각 매커니즘이 없고, 토큰 흐름을 고려하지 못한다.

3. Learning to Memorize at Test Time

3.1 Long-term Memory

신경 메모리의 설계 동기와 구조 설명하는 part

  • 신경 기반 장기 기억 모듈 설계를 위해서는, 과거의 내용을 추상화하여 파라미터에 인코딩할 수 있는 모델이 필요하다. 단순하게 신경망을 학습시켜 학습 데이터를 암기할 경우, 모델의 일반화 성능을 제한하거나 개인정보 보호 문제, 테스트 시의 성능 저하 등의 문제가 발생한다.
    \rarr Meta-Learning 기반의 Memory 모델이 필요. 이는 test time에 데이터를 어떻게 암기하고 잊을지를 학습

Learning Process and Surprise Metric

핵심 idea : 과거 정보 x1,,xt1x_1,,x_{t-1}을 장기 신경 기억 모듈 M\mathcal{M}의 파라미터에 압축 \rarr 온라인 학습 문제로 다룸

  • 모델의 놀라움(기대와 다른 것)을 gradient로 보아, gradient가 클수록 현재 입력 데이터가 과거 데이터와 다름을 의미하도록 구성

    신경 기억 모듈 M\mathcal{M} 업데이트
    Mt=Mt1θt(Mt1;xt)M_t = M_{t-1} - \theta_t\nabla \ell(M_{t-1};x_t)
    여기서, surprise score =(Mt1;xt)= \ell(M_{t-1};x_t)

  • surprise score 단점 : 매 surprising moment 이후 나타나는 중요한 정보를 놓칠 수 있음 \rarr gradient가 단계를 거칠수록 매우 작아져 local minimum에 빠지고, 시퀀스의 일부 정보를 놓칠 수 있음

    improvement
    surprise metric을 past surprise와 momentary surprise로 나눔
    - past surprise(St1S_{t-1}) : 아주 최근의 놀라움 정도
    - momentray surprise((Mt1;xt)\nabla \ell(M_{t-1};x_t)) : 들어오는 데이터의 놀라움 정도

    \rarr momentum을 사용하는 경사하강법과 유사
    StS_t : 시간에 걸친 surprise 기억 역할을 하는 모멘텀
    ηt\eta_t : 데이터에 surprise decay, 시간에 따라 surpirse가 어떻게 소멸되는지를 제어하는 함수
    θt\theta_t : 순간적 surpris가 최종 surprise 척도에 얼마나 반영되어야 할지를 조절하는 항목

η\eta는 다음 같은 경우에 따라 다르게 조절
1. 문맥의 변화로 인해 last surprise를 무시해야하는 경우 : ηt0\eta_t \rarr 0
2. 현재 토큰이 직전 토큰들과 매우 관련이 있어 last surprise를 통합하는 경우 : ηt1\eta_t \rarr 1

objective

손실함수 (;)\ell(\cdot;\cdot) = 메모리가 test time에 따라야 할 목표 함수

본 연구에서 주로 다루는 Memorysms Associative Memory

  • 과거 데이터를 key-value 쌍으로 저장
  • 입력 xtx_t일 때 2개의 Linear layer을 사용하여 xtx_t를 key와 value로 변환
    kt=xtWK,vt=xtWVk_t=x_tW_K , v_t=x_tW_V

이렇게 key,value를 정의한 이후, key와 value 간의 관계를 학습해야하기에 손실함수를 다음과 같이 정의

  • Mt1(kt)\mathcal {M}_{t-1}(k_t) : 현재 Memory state에서 key ktk_t에 대응하는 value
  • vtv_t : 실제 value

inner loop에서 Memory module의 weight를 최적화하고, outer loop에서 전체 아키텍처의 나머지 파라미터 WK,WVW_K,W_V 최적화

Forgetting Mechanism

긴 시퀀스를 처리할 때에는 모델이 어떤 과거 정보를 잊어야 할지 효율적으로 관리하는 것이 중요
본 논문에서는, Adaptive Forgetting Mechanism 사용

αt[0,1]\alpha_t \in [0,1] : Memory를 얼마나 지울지 결정 \rarr 0이면 기존 Memory 유지, 1이면 전체 Memory 초기화

Memory Architecture

본 논문에서는 장기 기억의 아키텍쳐로 최소 1개 층 이상의 MLPs 사용
why? 장기 기억의 설계를 잘 동기화하고, 그것이 어떻게 아키텍처 내에 통합될 수 있는지를 중심적으로 다루기 위해

  • 벡터값 or 행렬값 메모리를 사용할 때, 메모리 모듈은 과거 데이터를 압축하여 선형화된 형태로 맞춤

Retrieving a Memory

어떻게 메모리에서 정보를 검색할지를 결정
본 논문에서는 Memory 검색에 있어 weight를 업데이트하지 않고 단순 Forward pass 사용

메모리 검색 순서
1) Query Vector 생성
qt=xtWQq_t = x_tW_Q
2) Memory Retrieval
yt=M(qt)y_t = \mathcal{M}^*(q_t)
memory에서 retrieve 실행할 때 weight 업데이트 진행하지 않고 단순 inference 사용

3.2 How to Parallelize the Long-term Memory Training

long-term memory module의 학습에 사용되는 Associative Memory Losss는 결국 gradient descent+momentum+weight decay를 최적화 하는 과정과 동일
+ 본 논문에서는 Tensorization, Matmul 연산을 활용해 학습 병렬화 수행

Parameters as the Function of Chunks

기존 : αt,θt,ηt\alpha_t,\theta_t,\eta_t와 같은 파라미터들이 입력 데이터 xtx_t에 따라 변화하는 input-dependent한 방식으로 모델의 표현력을 증가시킴 \rarr 계산량 많음
본 논문 : chunk 내에서는 모든 토큰이 동일한 파라미터 값을 공유 \rarr 모델의 표현력은 조금 줄어들지만 계산량이 줄어 학습 속도를 크게 개선

3.3 Persistent Memory

효과적인 Memory System이 구축되기 위해서는 input에 의존하는 파라미터 뿐만 아니라 input과 무관한 파라미터 또한 포함해야 함
\rarr Persistent Memory : 입력과 관계없이 학습 가능한 특정 task에 대한 정보 저장
P=[p1,p2,,pNp]P = [p_1,p_2,\cdot\cdot\cdot,p_{N_p}] , xnew=[p1,p2,,pNp]xx_{new} = [p_1,p_2,\cdot\cdot\cdot,p_{N_p}]||x

Persist Memory 도입 이유
1) 메모리 관점
기존 long-term memory는 input에 따라 다르게 동작하는 Contextual memory였지만, 효과적인 memory system이라면 task 자체에 대한 지식이 있어야 함
\rarr Persistent Memory에 작업 지식 정보 저장
2) Feedforward Network 관점
fully connected layer의 ReLU활성화 함수를 softmax로 변경하면 해당 layer가 attention처럼 작동
\rarr Persistent Memory 모델이 특정 task에서 input과 독립적인 형태로 attention을 조절하는데 기여
3) 기술적 관점
transformer의 causal masked attention은 초기 토큰들에 높은 가중치를 할당하는 경향이 있어, 후반부 토큰 정보가 충분히 반영되기 어려움
\rarr Persistent Memory를 사용해 attention wieght를 균등하게 분배

4. How to Incoporate Memory?

4.1 Memory as Context


long-term Memory 역할을 수행하는 첫 번째 방법 : Memory를 Context로써 사용
즉, 현재 정보를 처리할 때 과거 정보를 context로 사용하는 것으로 다음의 순서에 따라 계산

1단계) Retreiving Memory

긴 시퀀스 x가 주어졌을 때, 이를 고정된 chunk 크기의 세그먼트 S(i)S^{(i)}로 분할
현재 S(t)S^{(t)}가 들어오면 이전 S(t1)S^{(t-1)}를 과거 정보로 사용, 이전 메모리 Mt1M_{t-1}에 Query를 입력하여 해당 세그먼트와 관련된 과거 정보 검색

2단계) Memory 활용하여 Attention 수행

1단계에서 검색을 통해 찾은 hth_t를 현재 입력과 함께 Attention Module에 전달. 이 때 Persistent Memory PP 도 함께 포함하여 새로운 입력 시퀀스 구성

이러한 새로운 시퀀스를 attention module에 전달하여 최종 출력 얻음

3단계) Memory Update

2단계에서 얻은 Attention module의 출력을 사용하여 Long-term Memory 업데이트
최종 출력 oto_t : Attention 결과로 생성된 yty_t와 업데이트된 memory Mt\mathcal{M}_t^*를 결합하여 생성

위와 같은 방식의 장점 3가지
1) Attention이 과거 정보가 필요한지 여부를 스스로 결정
과거 memory를 무조건 사용하는 것이 아닌 필요할 때만 사용
2) Attention이 Memory에 저장할 정보의 중요도 조절
모든 데이터를 기억하지 않고 중요한 정보만 저장하여 memory overflow 방지
3) Test-time에서도 Long-term Memory가 계속해서 학습
Persistent Memory는 고정되어 있고, Attention module은 새로운 입력을 바탕으로 최적의 정보를 계속 선택하면서 long-term memory가 계속 업데이트 된다.

4.2 Gated Memory


SWA와 Long-term Memory를 결합하는 아키텍처

  • 입력을 chunk로 나누지 않고, 전체 sequence를 그대로 사용
  • SWA : Short-term Memory역할 수행
  • Neural Memory : Long-term Memory 역할

1단계) long-term memory

input xx에 persistent memory를 추가하여 새로운 입력 x~\tilde x 생성

2단계) sliding window attention(SWA)

SWA를 사용하여 현재 context에서 필요한 정보 학습 후 두 결과를 결합하여 최종 출력 생성

단기기억으로 SWA, 장기기억으로 neural memory를 구성하여 독립적인 branch를 사용하므로 multi-head attention과 유사한 효과를 얻을 수 있음

4.3 Memory as a Layer

neural memory를 neural network의 독립적인 layer로 쌓은 방법

  • RNN과 SWA를 같이 사용하던 Hybrid 모델 구조와 유사
  • persistent memory 추가 \rarr input을 Neural Memory에 통과 \rarr SWA 적용하여 최종 출력 생성
  • Neural Memory를 순차적으로 연결하고, 독립적인 layer처럼 사용

하지만, 이 방식은 attention과 Neural Memory의 장점을 모두 활용하기 어려움
모델의 성능이 각 layer의 표현력에 의존하기 때문에 각 layer가 강력하지 않으면 전체 모델 성능이 제한되고, attention과 neural memory간의 상호작용이 부족

Memory Without Attention

Attention 없이 Neural Memory만 사용하는 버전 소개

  • MAL에서 Attention을 제거하고, Neural Memory만 사용하는 것으로 LMM 또는 Titans라고 부름
  • 그래도 논문에서 이 모델을 주장한 이유는 진정한 Long-term Memory라면 Short-term Memory 없이도 독립적으로 강력한 성능을 내야한다라고 생각하기 때문이다.

4.4 Architectural Details

titans 모델을 효과적으로 학습시키기 위해 사용한 기술들
1. Residual Connection
-모든 Block에서 residual connection 사용
2. siLU 활성화 함수 사용
-siLU은 ReLU보다 더 Smooth하고 Gradient 소실 문제가 적음
-최근 Transformer와 같은 대형 모델에서 더 좋은 성능을 보임

  1. L2 Norm 사용
    -Query, Key의 크기를 일정하게 유지
  2. 1D Depthwise-Separable Convolution 추가
    -Query,Key,Value 계산 후 1D Convolution 사용하여 성능 향상
  3. Gating Mechanism
    -출력 projection 전에 정규화 + Gating 적용
  4. Theorem 4.1.
    긴 context 문제가 있어 복잡한 문제 해결에 어려움

5. Experiments

앞서 위에 언급한 다섯가지 질문에 대해서 실험할 예정
1. How do Titans perform compared to baselines in downstream tasks? \rarr 5.2, 5.6, 5.7
2. What is the actual context length of Titans? \rarr 5.3, 5.4
3. How do Titans scale with respect to context length? \rarr 5.8
4. How the depth of memory can affect both performance and efficiency? \rarr 5.5
5. What is the contribution of each Titans' component in its preformance? \rarr 5.9

5.1 Experimental Setup

model

Titans : MAC, MAG, MAL, Neural Memory
비교 모델 : Transformer++, ResNet, GLA, Mamba, DeltaNet, TTT, Gated DeltaNet 등

scale

170m, 340m, 400m, 760m parameters

Dataset

  • FineWeb-Edu dataset
  • trained on 15B tokens sampled from the dataset
  • trained on 30B tokens from the dataset

5.2 Language Modeling

Perplexity와 Accuracy로 모델 성능 평가

  • LMM : 하이브리드 모델이 아닌 단일 아키텍처들과 비교했을 때, 모든 baseline model보다 더 뛰어난 pelplexity
  • Titans 하이브리드 변형 : Samba와 Gated DeltaNet-H2와 같은 다른 하이브리드 모델보다 더 뛰어난 성능
  • Titans 간 비교 : MAC와 MAG가 MAL보다 전반적으로 더 나은 성능

5.3 Needle in a Haystack

각 시퀀스에서 특정 정보를 얼마나 효과적으로 검색할 수 있는지 평가!

  • Titans 모델들이 baseline models 보다 더 좋은 성능을 보임
  • 특히 MAC 모델이 가장 좋은 성능을 보임
    \rarr 효과적인 메모리 용량 관리(forgetting mechanism)와 메모리 업데이트 규칙(momentum)이 장기 컨텍스트 처리에 중요함을 보여줌

5.4 BABILong Benchmark

극단적으로 긴 문서에서의 추론 능력을 평가
Few-shot learning과 Fine-Tuning Setting으로 실험

  • Titans MAC가 훨씬 적은 파라미터 수, 소규모로 미세 조정되었음에도 불구하고 모든 baseline models보다 뛰어난 성능
    \rarr In-context Online Memory Learner가 과거 데이터를 효과적으로 저장하는데 큰 역할 수행

5.5 The Effect of Deep Memory

Neural Memory의 깊이가 모델 성능과 학습 속도에 미치는 영향 평가
Neural Memory 모듈의 깊이를 1~4로 변경하면서 실험 진행

  • 깊은 Memory일수록 성능이 증가
  • seqeunce length에 따라 모든 모델의 학습 속도가 선형적으로 증가

5.6 Time Series Forecasting

  • Long-term Dependency, Weight Decay, Surprise Metric 으로 인해 좋은 성능
    \rarr 자연어 뿐만 아니라 시계열 예측과 같은 다른 task에서도 뛰어난 성능

5.7 DNA Modeling

  • LMM은 다양한 다운스트림 유전체학 작업에서 최첨단 아키텍처들과 경쟁력 있는 성능
    \rarr 자연어 뿐만 아니라 DNA와 복잡한 biological sequence 데이터에서도 효과적으로 작동

5.8 Efficiency

  • 학습 속도 : MAL 모델이 Titans 모델 중에서 가장 빠름

5.9 Ablation Study

업로드중..
모든 구성요소가 성능에 기여
weight Decay를 뺐을 때 성능이 가장 크게 감소

6. Conclusion

Titans의 핵심 아이디어 및 LMM의 특징

  • LMM은 test time에도 학습하는 meta in-context learner로 설계
  • surprising 정보나 surprising tokens에 가까운 정보들을 적응적으로 기억
  • 기존의 RNN 모델과 비교했을 때, 더 expressive한 메모리 업데이트 및 저장 메커니즘 가짐
    - 모멘텀 기반 규칙, 깊은 비선형 메모리, 망각 메커니즘
profile
Data Scientist

0개의 댓글