xLSTM: Extended Long Short-Term Memory (arXiv 2024)
Abstract
1990s -> Long Short-Term Memory with constant error carousel and gating
LSTMs -> contributed to numerous deep learning success stories, including the first LLMs
Transformers with parallelizable self-attention outpace LSTMs at scale
How far do we get in language modeling when scaling LSTMs to billions of parameters, leveraging the latest techniques from modern LLMs, but mitigating known limitations of LSTMs?
Exponential gating with appropriate normalization and stabilization techniques.
Modify LSTM memory structure
sLSTM: scalar memory, scalar update, new memory mixing
mLSTM: fully parallelizable with matrix memory and a covariance update rule
Integrating sLSTM and mLSTM into residual block backbones -> xLSTM blocks -> residually stacked into xLSTM architectures
xLSTM performs favorably in both performance and scaling compared to Transformers and State Space Models
1. Introduction
LSTM
(Hochreiter, 1991; Hochreiter & Schmidhuber, 1997)
Architecture
Constant error carousel and gating to overcome the vanishing gradient problem of RNNs
Constant error carousel -> additive update of the cell state c_{t-1} by the cell input z_t, controlled by sigmoid gates (blue in the paper's figure)
The input gate i_t and the forget gate f_t control the update
The output gate o_t controls the output of the memory cell (the hidden state h_t)
The cell state is normalized or squashed by ψ
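In symbols, the constant error carousel and the gated output of the memory cell are:
c_t = f_t c_{t-1} + i_t z_t      (additive cell-state update)
h_t = o_t ψ(c_t)                 (gated hidden state)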
History
successfully applied to various domains
remained the dominant approach to text generation until the rise of Transformers in 2017
still used in highly relevant applications and have stood the test of time
Limitations
Inability to revise storage decisions
Left figure - low performance in Nearest Neighbor Search
Limited storage capacities (compressed into scalar cell state)
Right figure - low performance in Rare Token Prediction
Lack of parallelizability due to memory mixing (enforced sequential processing)
What performances can we achieve in language modeling when overcoming these limitations and scaling LSTMs to the size of current Large Language Models?
2. Extended Long Short-Term Memory
Main modification -> exponential gating and novel memory structure
memory cell -> input, output, and forget gates (the forget gate was added by Gers et al., 2000)
LSTM memory update rule at t
w: input weight vectors between the input x_t and the cell input, input gate, forget gate, and output gate
r: recurrent weight vectors between the hidden state h_{t-1} and the cell input, input gate, forget gate, and output gate
b: corresponding bias terms
ϕ: cell input activation function
ψ: hidden state activation function -> normalizes or squashes the cell state
All gate activation functions are sigmoid
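With this notation, the cell input and gates at time t are computed as:
z_t = ϕ(w_z^⊤ x_t + r_z h_{t-1} + b_z)
i_t = σ(w_i^⊤ x_t + r_i h_{t-1} + b_i)
f_t = σ(w_f^⊤ x_t + r_f h_{t-1} + b_f)
o_t = σ(w_o^⊤ x_t + r_o h_{t-1} + b_o)
and the state update is c_t = f_t c_{t-1} + i_t z_t, h_t = o_t ψ(c_t) as above.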
2.2 sLSTM
Exponential gates
to empower the LSTM's ability to revise storage decisions
with normalization and stabilization
Input & Forget gates can have exponential activation functions
Normalizing and stabilizing exponential gates
Normalizer state n_t: sums up the input gate weighted by the product of all subsequent forget gates; the hidden state uses the cell state divided by n_t
Exponential gating can lead to overflow
-> additional stabilizer state m_t (a running maximum in log space)
Using the stabilized gates f′ and i′ changes neither the output of the whole network nor the derivatives of the loss
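Putting this together, the sLSTM cell (scalar form) with stabilization reads:
c_t = f_t c_{t-1} + i_t z_t
n_t = f_t n_{t-1} + i_t
h_t = o_t (c_t / n_t)
i_t = exp(w_i^⊤ x_t + r_i h_{t-1} + b_i)
f_t = σ(w_f^⊤ x_t + r_f h_{t-1} + b_f)   or   exp(w_f^⊤ x_t + r_f h_{t-1} + b_f)
m_t = max(log f_t + m_{t-1}, log i_t)                     (stabilizer state)
i′_t = exp(log i_t − m_t),   f′_t = exp(log f_t + m_{t-1} − m_t)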
New Memory Mixing
Multiple memory cells like the original LSTM
Multiple memory cells enable memory mixing via recurrent connections R from hidden state vector h to memory cell input z and the gates i, f, o
Can have multiple heads with memory mixing within each head (not across heads)
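As a minimal illustration of head-wise memory mixing (a sketch, not the paper's implementation: the 2-head sizes and random weights are placeholders), the recurrent matrix of, e.g., the input gate is block-diagonal, so each head only mixes within itself:

import torch

num_heads, head_dim = 2, 3
# One recurrent block per head; zeros everywhere else, so no mixing across heads.
blocks = [torch.randn(head_dim, head_dim) for _ in range(num_heads)]
R_i = torch.block_diag(*blocks)              # shape: (6, 6)

h_prev = torch.randn(num_heads * head_dim)   # previous hidden state h_{t-1}
i_preact_recurrent = R_i @ h_prev            # each head sees only its own slice of h_{t-1}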
2.3 mLSTM
Extend the LSTM memory cell (c ∈ R) to a matrix C ∈ R^{d×d} -> enhanced storage capacity
retrieval is performed via a matrix multiplication
Covariance update rule
Store a key k_t ∈ R^d and a value v_t ∈ R^d at time t
The value v_t should be retrieved by a query vector q_{t+τ} ∈ R^d at time t+τ
Covariance update rule for storing a key-value pair:
C_t = C_{t-1} + v_t k_t^⊤
key & value -> zero mean due to layer-norm
Covariance update rule is optimal for maximal separability of retrieved binary vectors (maximal signal/noise ratio)
Forget gate corresponds to decay rate and the input gate to the learning rate of Fast Weight Programmers
Normalizer state n_t
weighted sum of key vectors, where each key is weighted by its input gate and all subsequent forget gates
For retrieval, use the absolute value of the dot product of query and normalizer state, lower-bounded by a threshold of 1 (following Retentive Networks)
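In equations, the gated covariance update and retrieval are:
C_t = f_t C_{t-1} + i_t v_t k_t^⊤          (matrix memory update)
n_t = f_t n_{t-1} + i_t k_t                 (normalizer state)
h_t = o_t ⊙ ( C_t q_t / max{ |n_t^⊤ q_t| , 1 } )   (retrieval with lower-bounded denominator)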
Can have multiple heads and multiple memory cells; without memory mixing, the two are equivalent
Same exponential gating with the same stabilization technique as the sLSTM
No memory mixing -> parallelizable!
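A minimal single-head PyTorch sketch of this recurrence (a sketch under simplifying assumptions: biases and the exponential-gate stabilization are omitted, and the d×d projection matrices are illustrative placeholders). Since q, k, v and all gates depend only on the current input x_t and never on h_{t-1}, the gate values for the whole sequence can be precomputed and the recurrence unrolled in parallel:

import torch

def mlstm_step(C, n, x, W_q, W_k, W_v, w_i, w_f, W_o):
    d = x.shape[-1]
    q, k, v = W_q @ x, (W_k @ x) / d**0.5, W_v @ x    # query, key, value from x_t only
    i = torch.exp(w_i @ x)                            # scalar exponential input gate (unstabilized)
    f = torch.sigmoid(w_f @ x)                        # scalar forget gate
    o = torch.sigmoid(W_o @ x)                        # vector output gate
    C = f * C + i * torch.outer(v, k)                 # covariance update of the matrix memory
    n = f * n + i * k                                 # normalizer state
    h = o * (C @ q) / torch.clamp(torch.abs(n @ q), min=1.0)   # retrieval
    return C, n, h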
2.4 xLSTM Architecture
xLSTM Blocks
The block should non-linearly summarize the past in high-dimensional space to better separate contexts
Separating histories is a prerequisite to correctly predicting the next sequence element.
Cover's Theorem
In a higher-dimensional space, non-linearly embedded patterns are more likely to be linearly separable than in the original space.
Residual blocks
Residual block with post up-projection (like Transformers - left)
non-linearly summarizes the past in the original space
-> linearly maps into a high-dim space
-> apply non-linear activation function
-> linearly maps back to the original space
=> For xLSTM block with sLSTM
Residual block with pre up-projection (like SSMs - right)
linearly maps into a high-dim space
-> non-linearly summarizes the past in the high-dim space
-> linearly maps back to the original space
=> For xLSTM block with mLSTM (memory capacity issue)
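A rough PyTorch sketch of the two block bodies (illustrative only: `sequence_mix` is a placeholder standing in for an sLSTM or mLSTM layer, and the projection factors and GELU are assumptions; the paper's actual blocks include additional components such as normalization, gating, and convolutions):

import torch.nn as nn

class PostUpProjectionBody(nn.Module):
    # sLSTM-style: summarize the past in the original space, then up-project, non-linearity, down-project.
    def __init__(self, d, sequence_mix, factor=4):
        super().__init__()
        self.mix = sequence_mix                  # placeholder: operates in dimension d
        self.up = nn.Linear(d, factor * d)
        self.act = nn.GELU()
        self.down = nn.Linear(factor * d, d)

    def forward(self, x):
        return self.down(self.act(self.up(self.mix(x))))

class PreUpProjectionBody(nn.Module):
    # mLSTM-style: up-project first, summarize the past in the high-dim space, down-project.
    def __init__(self, d, sequence_mix, factor=2):
        super().__init__()
        self.up = nn.Linear(d, factor * d)
        self.mix = sequence_mix                  # placeholder: operates in dimension factor * d
        self.down = nn.Linear(factor * d, d)

    def forward(self, x):
        return self.down(self.mix(self.up(x)))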
sLSTM diagram
mLSTM diagram
xLSTM Architecture
residually stacking building blocks
pre-LayerNorm (the most common choice)
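A short sketch of the residual stacking with pre-LayerNorm (assuming the hypothetical block bodies from the sketch above, each mapping R^d -> R^d):

import torch.nn as nn

class XLSTMStack(nn.Module):
    def __init__(self, d, block_bodies):
        super().__init__()
        self.norms = nn.ModuleList(nn.LayerNorm(d) for _ in block_bodies)
        self.bodies = nn.ModuleList(block_bodies)

    def forward(self, x):
        for norm, body in zip(self.norms, self.bodies):
            x = x + body(norm(x))   # pre-LayerNorm residual connection
        return x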
2.5 Memory and Speed Considerations
Linear
Linear computation and constant memory complexity with respect to sequence length
Compressive memory
trade-off
the mLSTM memory adds no parameters, but it is computationally expensive (d×d matrix memory and d×d updates per step)
trade off memory capacity against computational complexity
with parallel GPU computation, this overhead has only a minor effect
CUDA for sLSTM
Not parallelizable due to memory mixing
CUDA implementation with GPU memory optimizations to the register level
less than 2 times slower than mLSTM
3 Related Work
Linear Attention
Synthesizer (learns synthetic attention weights)
Linformer (approximates self-attention by a low-rank matrix)
Linear Transformer (linearizes the attention mechanism)
Performer (approximates attention with positive orthogonal random features)
SGConv (replaces attention by a fast long convolution)
State Space Models
linear in context length
S4, DSS, GSS, S5, BiGS, H3, Mamba
Recurrent Neural Networks
recent development of RNN in LLMs
LRUs (RNNs with deep Linear Recurrent Units)
Hierarchically Gated Linear RNN (HGRN, HGRN2)
RWKV
Gating
Key ideas of LSTM
Rediscovered and reinterpreted by
HGRN, HGRN2, Gated Linear Attention, Gated State Space models, Bidirectional Gated SSM, Moving Average Equipped Gated Attention, RWKV, Mamba
Covariance Update Rule
mLSTM cell with covariance update rule
Fast Weight Programmers, RWKV-5, RWKV-6, Retention, Linear Transformer, HGRN2
Most Related
Closest models: Retention, RWKV, HGRN2
Matrix memory & gating
No memory mixing
Memory mixing enables solving state-tracking problems => more expressive than SSMs and Transformers
Residually Stacking Architectures
Almost all contemporary large deep learning models residually stack building blocks
e.g., deep convolutional networks and Transformers
4 Experiments
Evaluation with a focus on language modeling
4.1 specific capabilities on synthetic tasks
4.2 validation-set perplexity (trained on 15B tokens of SlimPajama)
4.3 thorough language modeling experiments (trained on 300B tokens of SlimPajama)
Notation
xLSTM[a:b] -> ratio a:b of mLSTM-based versus sLSTM-based blocks
xLSTM[7:1]: out of every eight blocks, seven are mLSTM-based and one is sLSTM-based
Typical total number of blocks: 48
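For illustration only (the paper fixes the exact positions of the sLSTM blocks within the stack; the even spacing here is a hypothetical choice), an xLSTM[7:1] stack of 48 blocks could assign block kinds as:

num_blocks, m_per_group, s_per_group = 48, 7, 1
period = m_per_group + s_per_group
kinds = ["sLSTM" if i % period == period - 1 else "mLSTM" for i in range(num_blocks)]
# 48 blocks -> 42 mLSTM-based and 6 sLSTM-based, i.e. a 7:1 ratio
assert kinds.count("sLSTM") == 6 and kinds.count("mLSTM") == 42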
4.1 Synthetic Tasks and Long Range Arena
Test of xLSTM's Exponential Gating with Memory Mixing