LLM Acceleration
Reducing the number of layers for each token by exiting early during inference
Speculative decoding
main model + draft model
larger memory footprint and complexity
faster inference
Self-Speculative Decoding
contribution
training recipe that combines layer dropout and early exit loss
the recipe makes the model more robust to exiting at earlier layers, essentially creating different-sized sub-models within the same model
self-speculative decoding solution that decodes with earlier layers and verifies and corrects with later layers
Fig 2a -> Llama1 7B + HumanEval coding dataset
projected each layer's output embeddings through the LM head, applied softmax, and took the index of the largest element as that layer's predicted token (unembedding)
token predictions in earlier layers appear to be irrelevant
in later layers, token predictions converge to the final prediction
most of the time, the final token prediction is reached only a few layers before the end
intermediate layers are sometimes hesitant and change their mind
on average, a token requires 23.45 of the model's 32 layers
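A minimal sketch of this per-layer unembedding analysis, assuming a HuggingFace-style Llama checkpoint; the model name and prompt are only illustrative:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf").eval()

inputs = tok("def fibonacci(n):", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs, output_hidden_states=True)

# hidden_states[0] is the embedding output; [1..L] are the transformer layers.
for layer_idx, h in enumerate(out.hidden_states[1:], start=1):
    # Unembedding: project the last position through the final norm + LM head,
    # then take the argmax of the softmax as that layer's token prediction.
    logits = model.lm_head(model.model.norm(h[:, -1, :]))
    token_id = logits.softmax(dim=-1).argmax(dim=-1)
    print(f"layer {layer_idx:2d} -> {tok.decode(token_id)!r}")
```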
need to make the model use fewer layers
and make it less hesitant, so it does not keep changing its mind in intermediate layers
skipping layers during training (dropout)
unembedding
typically LLMs are trained to unembed at the last transformer layer
need to add a loss function during training so the LM head can understand embeddings of earlier layers
a shared LM head for early exit
makes the model behave like an ensemble of different-depth models sharing the same weights
unstructured dropout (original)
large models (Llama, GPT-3, PaLM) trained on large corpora don't use it
enable the training to learn across an ensemble of many models
multiplicative noise
stochastically skipping layers
LayerDrop makes LMs robust to removing layers at inference
layer dropout for training decoder-only models or scaling LMs has not been explored
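A minimal sketch of layer dropout in a decoder stack; the linear depth-scaled rate schedule and per-forward skipping here are illustrative assumptions, not the paper's exact recipe:

```python
import torch
import torch.nn as nn

class LayerDropoutStack(nn.Module):
    """Transformer layer stack that stochastically skips layers during training."""

    def __init__(self, layers: nn.ModuleList, p_max: float = 0.2):
        super().__init__()
        self.layers = layers
        L = len(layers)
        # Deeper layers are dropped more often: p_l scales linearly up to p_max.
        self.p = [p_max * l / max(L - 1, 1) for l in range(L)]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for l, layer in enumerate(self.layers):
            if self.training and torch.rand(()) < self.p[l]:
                continue  # skip this layer; the residual stream passes through unchanged
            x = layer(x)
        return x
```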
branch modules at different exit points in a deep learning network + additional loss
in LMs, early exit was explored in encoder-only models
dedicated LM head for each decoder layer
SkipDecode
additional FC layer
auto-regressive decoding is slow while measuring the likelihood of a group of generated tokens in parallel is faster
a draft model (fast, less accurate) generates tokens; the main model (slow, more accurate) verifies and corrects them
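A minimal sketch of the parallel verification step under greedy decoding, assuming a HuggingFace-style causal LM that returns `.logits`; the helper name is hypothetical:

```python
import torch

@torch.no_grad()
def accept_prefix(main_model, prefix_ids: torch.Tensor, draft_ids: torch.Tensor):
    """prefix_ids: (1, t) tokens so far; draft_ids: (1, k) tokens from the draft model."""
    # Score the prefix plus all k drafted tokens in a single forward pass.
    logits = main_model(torch.cat([prefix_ids, draft_ids], dim=-1)).logits
    # The prediction at position t-1+i is the main model's choice for draft token i.
    preds = logits[:, prefix_ids.shape[1] - 1 : -1].argmax(dim=-1)
    # Accept the longest prefix of draft tokens the main model agrees with.
    n_ok = int((preds == draft_ids).long().cumprod(dim=-1).sum())
    return draft_ids[:, :n_ok]
```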
LM head should be capable of unembedding outputs of different layers
During training, supervise the model directly to connect the early exit layers to the LM head
adding early exit loss for all layers at every iteration slows down training and reduces accuracy
use a curriculum instead:
rotational early exit curriculum
gradual early exit curriculum
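A minimal sketch of the early exit loss with a rotational curriculum; the rotation interval, per-layer weighting, and function signature are assumptions for illustration:

```python
import torch
import torch.nn.functional as F

def early_exit_loss(hidden_states, lm_head, norm, labels, step, rotation=3):
    """hidden_states: list of per-layer outputs, each (B, T, D); labels: (B, T)."""
    L = len(hidden_states)
    total, count = 0.0, 0
    for l, h in enumerate(hidden_states):
        # Rotational curriculum: only a rotating subset of layers is supervised
        # at each step; the last layer always is.
        if l != L - 1 and (l + step) % rotation != 0:
            continue
        logits = lm_head(norm(h))  # the shared LM head unembeds this layer
        loss = F.cross_entropy(
            logits[:, :-1].reshape(-1, logits.size(-1)),  # predict the next token
            labels[:, 1:].reshape(-1),
        )
        total = total + (l + 1) / L * loss  # weight later layers more heavily
        count += 1
    return total / max(count, 1)
```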
run the first E transformer layers (up to the chosen exit layer), then skip to the model's LM head
the final output is the LM head's unembedding of that early layer's output
uses a single model while obtaining the latency gains of traditional speculative decoding
Self Drafting and Self-Verification
Self Drafting: using the early exit to draft tokens
Self Verification: using the remaining layers to validate the prediction
Cache Reuse: unifies the KV cache between drafting and verification and stores the exit layer's query cache
leverages the full LLM to predict the next token for each draft token in a single forward pass
find the point where the draft tokens and verified tokens agree
all the draft tokens up to the disagreement point are added to the output, along with the next verified token, and drafting continues from there
verification only computes the remaining layers, since the first E layers' results are reused
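A minimal sketch of the full self-speculative loop; `forward_early` and `forward_full` are hypothetical helpers standing in for "run the first E layers plus the shared LM head" and "run all layers", and KV/exit-query cache handling is omitted for brevity:

```python
import torch

@torch.no_grad()
def self_speculative_generate(model, input_ids, exit_layer=8, num_draft=6, max_new=128):
    out = input_ids
    while out.shape[1] - input_ids.shape[1] < max_new:
        # Self Drafting: decode num_draft tokens with only the first exit_layer layers.
        ids, draft = out, []
        for _ in range(num_draft):
            logits = model.forward_early(ids, exit_layer)   # hypothetical helper
            nxt = logits[:, -1:].argmax(dim=-1)
            draft.append(nxt)
            ids = torch.cat([ids, nxt], dim=-1)
        draft = torch.cat(draft, dim=-1)                    # (1, num_draft)

        # Self Verification: the full model scores every draft token in one pass.
        # In the paper, the KV cache of the first exit_layer layers is reused here
        # (Cache Reuse), so only the remaining layers are actually recomputed.
        logits = model.forward_full(torch.cat([out, draft], dim=-1))  # hypothetical
        preds = logits[:, out.shape[1] - 1 : -1].argmax(dim=-1)
        n_ok = int((preds == draft).long().cumprod(dim=-1).sum())

        # Accept the agreeing prefix plus the verified correction, then repeat.
        fix = preds[:, n_ok : n_ok + 1] if n_ok < num_draft else logits[:, -1:].argmax(dim=-1)
        out = torch.cat([out, draft[:, :n_ok], fix], dim=-1)
    return out
```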
continual pretraining with 52B tokens
text + code
Llama2 7B (32 layers)
Llama2 13B (40 layers)
pretraining from scratch with 26B tokens
text + code
Llama2 1.5B (24 layers)
Llama2 7B (32 layers)
used a higher LR when layer dropout was 0.0
finetuning on code data with 5.2B tokens
Llama1 7B
finetuning on a task-specific dataset: TOPv2
Llama 1.5B (24 layers)
LayerSkip is better than the baseline
at the last layer, LayerSkip shows only a minimal drop in accuracy
some classification tasks (multiple choice, TF) maintain relatively decent accuracy on earlier layers
generation tasks drop drastically
classification is evaluated on one token while generation is evaluated on many tokens
in MMLU, Llama2 13B baseline dropped from 55.2 to 49.2
NaturalQuestions LayerSkip's accuracy is higher at middle layer
Fig 10a
earlier layers are better than the baseline
LD+EE shows a big improvement
this is domain-specific data, scaled to 1.0
Fig 10b
when removing layers from the baseline, the model is not able to generate complete and accurate parses (0 EM)
LayerSkip shows 77% at layer 12
there is a regression at the final layer, reducing accuracy by 3%
used EM, ROUGE-2
compared with common models and tasks in Draft & Verify
used greedy decoding and max 512 tokens
50000 steps
batch size per device: 4
context window: 4096
number of GPUs: 32, 64, 128
middle-layer PPL increases by default (without early exit loss)
could open the door to understanding the dynamics of transformers
self-speculative decoding doesn't require changing a model's weights
however, its hyperparameters still need to be tuned
when pretraining from scratch with layer dropout, increasing the LR is needed, and tuning the LR is tricky
layer dropout + early exit loss improves accuracy and speed
hope this can be combined with PEFT
in the future, the accuracy of early-exit layers could be improved, and dynamic conditions for choosing a different exit layer per token could be explored
Unlike pruning, this technique trains and infers using only selected layers, and makes good use of the remaining layers to implement Self-Speculative Decoding. It seems that, in the end, not every layer of a Transformer is needed. Feels similar in spirit to CoT without prompting.