As model size grows, the memory bandwidth becomes a major bottleneck.
When deploying these models on distributed systems or multi-device platforms, the inter-device communication overhead impacts inference latency and energy consumption. ⇒ Quantization
Most existing quantization approaches → post-training quantization (PTQ)
simple and easy to apply
significant accuracy loss as precision goes lower (the model is not optimized for the quantized representation)
Quantization-aware training
better accuracy
requires continued training or fine-tuning
hard to converge when precision goes lower
unknown whether it follows the scaling law of LMs
Previous Attempts
CNN
BERT, machine translation
BitNet
1-bit transformer architecture
scales efficiently in terms of both memory and computation
binary weights + quantized activations
high precision for optimizer states and gradients
simple implementation
complements other acceleration methods (PagedAttention, FlashAttention, speculative decoding)
competitive with FP16 Transformers (perplexity, downstream tasks)
follows a scaling law similar to that of full-precision Transformers
2. BitNet
BitLinear in place of conventional matrix multiplication (see the sketch after this list)
other components are left at high precision (8-bit in the experiments)
residual connection and LayerNorm have negligible computation costs
QKV transformation cost is also much smaller than the parametric projection
models have to use high-precision probabilities to perform sampling (input/output embeddings are left at high precision)
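As a rough illustration of this layout, here is a minimal PyTorch sketch (the class name `FFNBlock` and the `linear_cls` hook are ours, not the paper's): only the projections are meant to become BitLinear layers, while LayerNorm and the residual path stay high-precision.

```python
import torch
import torch.nn as nn

class FFNBlock(nn.Module):
    """Illustrative sketch, not the paper's code: a feed-forward sub-block where
    only the two projections would be BitLinear; LayerNorm and the residual
    connection stay at higher precision, matching the notes above."""

    def __init__(self, d_model: int, d_ffn: int, linear_cls=nn.Linear):
        super().__init__()
        # In BitNet, linear_cls would be the BitLinear layer of Section 2.1;
        # nn.Linear is only a stand-in so this sketch runs on its own.
        self.norm = nn.LayerNorm(d_model)      # kept high-precision
        self.fc1 = linear_cls(d_model, d_ffn)  # intended: 1-bit weights, 8-bit activations
        self.fc2 = linear_cls(d_ffn, d_model)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual connection stays in the original (e.g. FP16) precision.
        return x + self.fc2(self.act(self.fc1(self.norm(x))))
```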
2.1 BitLinear
Binarize weights to +1 or -1 with signum function
Centralize weights to zero mean to increase the capacity within the limited numerical range
Use a scaling factor β after binarization to reduce the l2 error between the real-valued and binarized weights
$\tilde{W} = \mathrm{Sign}(W - \alpha)$
$\mathrm{Sign}(W_{ij}) = \begin{cases} +1, & \text{if } W_{ij} > 0 \\ -1, & \text{if } W_{ij} \le 0 \end{cases}$
$\alpha = \frac{1}{nm} \sum_{ij} W_{ij}$
$\beta = \frac{1}{nm} \lVert W \rVert_1$
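A minimal PyTorch sketch of this binarization; the helper name `binarize_weight` is ours, not the released code.

```python
import torch

def binarize_weight(W: torch.Tensor):
    """Zero-center W, take the sign, and compute the scale beta (sketch only)."""
    alpha = W.mean()       # alpha = (1/nm) * sum_ij W_ij
    beta = W.abs().mean()  # beta  = (1/nm) * ||W||_1
    # Sign(.): +1 where W - alpha > 0, -1 otherwise (zeros map to -1, as above)
    W_tilde = torch.where(W - alpha > 0, torch.ones_like(W), -torch.ones_like(W))
    return W_tilde, beta
```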
Quantize activations to b-bit precision with absmax
$Q_b = 2^{b-1}$
ϵ is a small floating-point number that prevents overflow in clipping
$\tilde{x} = \mathrm{Quant}(x) = \mathrm{Clip}\left(x \times \frac{Q_b}{\gamma},\, -Q_b + \epsilon,\, Q_b - \epsilon\right)$
$\gamma = \lVert x \rVert_\infty$
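A minimal sketch of this absmax quantization; the helper name `absmax_quant` and the ε value are illustrative assumptions.

```python
import torch

def absmax_quant(x: torch.Tensor, b: int = 8, eps: float = 1e-5):
    """Quantize activations to b bits with an absmax scale (sketch only)."""
    Qb = 2 ** (b - 1)                     # Q_b = 2^(b-1), e.g. 128 for b = 8
    gamma = x.abs().max().clamp(min=eps)  # gamma = ||x||_inf (guard against zero)
    x_tilde = (x * Qb / gamma).clamp(-Qb + eps, Qb - eps)
    return x_tilde, gamma
```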
For activations before non-linear functions (e.g., ReLU) → scale into $[0, Q_b]$ by subtracting the minimum of the inputs
$\tilde{x} = \mathrm{Quant}(x) = \mathrm{Clip}\left((x - \eta) \times \frac{Q_b}{\gamma},\, \epsilon,\, Q_b - \epsilon\right)$
$\eta = \min_{ij} x_{ij}$
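And a sketch of this shifted variant for pre-nonlinearity activations; again the helper name and ε are assumptions.

```python
import torch

def absmax_quant_nonneg(x: torch.Tensor, b: int = 8, eps: float = 1e-5):
    """Shifted b-bit quantization for activations feeding a non-linearity:
    subtract eta = min(x) so the result lies in [0, Q_b] (sketch only)."""
    Qb = 2 ** (b - 1)
    eta = x.min()                         # eta = min_ij x_ij
    gamma = x.abs().max().clamp(min=eps)  # gamma = ||x||_inf
    return ((x - eta) * Qb / gamma).clamp(eps, Qb - eps)
```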
activations are quantized to 8 bits in the experiments
Training → quantize per tensor / Inference → quantize per token
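The difference in granularity only changes how γ is computed, roughly as in this sketch (assuming the feature dimension is the last one):

```python
import torch

def absmax_scale(x: torch.Tensor, per_token: bool, eps: float = 1e-5):
    """One gamma shared by the whole tensor (training) vs. one gamma per token,
    i.e. per row along the feature dimension (inference). Sketch only."""
    if per_token:
        return x.abs().amax(dim=-1, keepdim=True).clamp(min=eps)
    return x.abs().max().clamp(min=eps)
```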
Matrix Multiplication
$y = \tilde{W}\tilde{x}$
Estimate the variance of the output y under the following assumption:
the elements in W and x are mutually independent and share the same distribution
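A sketch of the estimate this assumption enables, additionally treating the elements as (roughly) zero-mean, for a layer with $n$ input features:

$\operatorname{Var}(y) = \operatorname{Var}\left(\sum_{i=1}^{n} \tilde{w}_i \tilde{x}_i\right) = n\,\operatorname{Var}(\tilde{w}\tilde{x}) \approx n\,\mathbb{E}[\tilde{w}^2]\,\mathbb{E}[\tilde{x}^2]$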