BitNet: Scaling 1-bit Transformers for Large Language Models

임재석 · November 13, 2023

1. Introduction

LLM Hosting

  • High inference cost
  • High energy consumption

As the model size grows, memory bandwidth becomes a major bottleneck.
When deploying these models on distributed systems or multi-device platforms, the inter-device communication overhead further impacts inference latency and energy consumption.
⇒ Quantization

Most existing quantization approaches → post-training quantization

  • simple and easy to apply
  • significant loss of accuracy as precision gets lower (the weights are not optimized for the quantized representation)

Quantization-aware training

  • better accuracy
  • requires continued pre-training or fine-tuning
  • hard to converge as precision gets lower
  • unknown whether it follows the scaling law of language models

Previous work on binarization

  • CNNs
  • BERT and machine translation models

BitNet

  • 1-bit transformer architecture
  • scale efficiently in terms of both memory and computation
  • binary weights + quantized activations
  • high precision for optimizer states and gradients
  • simple implementation
  • complements other acceleration methods (PagedAttention, FlashAttention, speculative decoding)
  • compared with FP16 Transformers on perplexity (PPL) and downstream tasks
  • follows a scaling law similar to that of full-precision Transformers

2. BitNet

  • BitLinear replaces the standard matrix multiplication (i.e., nn.Linear)
  • the other components are left at 8-bit precision, because
    • the residual connections and LayerNorm have negligible computation cost
    • the cost of the QKV transformation is also much smaller than that of the parametric projections as the model grows
    • the model has to use high-precision probabilities to perform sampling, so the input/output embeddings are kept at high precision
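
To make this concrete, here is a minimal sketch (my own, not the paper's code) of a feed-forward sub-layer where only the two parametric projections are swapped for the `BitLinear` module sketched in Section 2.1 below, while the non-parametric parts stay in higher precision:

```python
import torch.nn as nn

class BitFFN(nn.Module):
    """Feed-forward sub-layer using BitLinear projections (see the BitLinear
    sketch in Section 2.1); everything else is left in higher precision."""
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.up = BitLinear(d_model, d_ff)     # replaces nn.Linear
        self.down = BitLinear(d_ff, d_model)   # replaces nn.Linear
        self.act = nn.GELU()                   # non-parametric, kept as-is

    def forward(self, x):
        return self.down(self.act(self.up(x)))
```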

2.1 BitLinear

  1. Binarize the weights to +1 or −1 with the signum function
  • Centralize the weights to zero mean to increase the capacity within the limited numerical range
  • A scaling factor β is used after binarization to reduce the l2 error between the real-valued and binarized weights (it appears in the output scaling below)

$$\tilde{W} = \mathrm{Sign}(W - \alpha)$$

$$\mathrm{Sign}(W_{ij}) = \begin{cases} +1, & \text{if } W_{ij} > 0, \\ -1, & \text{if } W_{ij} \le 0, \end{cases}$$

$$\alpha = \frac{1}{nm} \sum_{ij} W_{ij}, \qquad \beta = \frac{1}{nm} \|W\|_1$$
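
A minimal PyTorch sketch of this binarization step (the helper name `binarize_weight` is mine, not from the paper):

```python
import torch

def binarize_weight(w: torch.Tensor):
    """Zero-center the weights, binarize with Sign, and return the binary
    weights together with the scaling factor beta = ||W||_1 / (n*m)."""
    alpha = w.mean()                                   # α = (1/nm) Σ_ij W_ij
    w_bin = torch.where(w - alpha > 0,                 # Sign(W − α), with ≤ 0 mapped to −1
                        torch.ones_like(w), -torch.ones_like(w))
    beta = w.abs().mean()                              # β = (1/nm) ||W||_1
    return w_bin, beta
```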
  2. Quantize the activations to $b$-bit precision with absmax quantization
  • $Q_b = 2^{b-1}$
  • $\epsilon$ is a small floating-point number that prevents overflow when clipping

$$\tilde{x} = \mathrm{Quant}(x) = \mathrm{Clip}\left( x \times \frac{Q_b}{\gamma},\ -Q_b + \epsilon,\ Q_b - \epsilon \right), \qquad \gamma = \|x\|_\infty$$

  • For activations before non-linear functions (e.g., ReLU), scale them into $[0, Q_b]$ by subtracting the minimum of the inputs:

$$\tilde{x} = \mathrm{Quant}(x) = \mathrm{Clip}\left( (x - \eta) \times \frac{Q_b}{\gamma},\ \epsilon,\ Q_b - \epsilon \right), \qquad \eta = \min_{ij} x_{ij}$$

  • Activations are quantized to 8-bit in this work
  • Training → quantize per tensor / Inference → quantize per token
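
A hedged sketch of this absmax quantization (per tensor; the function name and defaults are illustrative):

```python
import torch

def quantize_activation(x: torch.Tensor, b: int = 8, eps: float = 1e-5,
                        before_nonlinear: bool = False):
    """Absmax-quantize activations to b bits. For inputs of a non-linearity
    (e.g., ReLU), shift by the minimum eta so the result lies in [0, Q_b]."""
    Qb = 2 ** (b - 1)
    gamma = x.abs().max()                                             # γ = ||x||_∞
    if before_nonlinear:
        eta = x.min()                                                 # η = min_ij x_ij
        x_q = torch.clamp((x - eta) * Qb / gamma, min=eps, max=Qb - eps)
    else:
        x_q = torch.clamp(x * Qb / gamma, min=-Qb + eps, max=Qb - eps)
    return x_q, gamma
```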

  3. Matrix multiplication between the binarized weights and the quantized activations

$$y = \tilde{W}\tilde{x}$$

The variance of the output $y$ is estimated under the following assumptions:

  • the elements in $W$ and $x$ are mutually independent and share the same distribution
  • $W$ and $x$ are independent of each other

$$\begin{aligned} \mathrm{Var}(y) &= n\,\mathrm{Var}(\tilde{w}\tilde{x}) \\ &= n\,E[\tilde{w}^2]\,E[\tilde{x}^2] \\ &= n\beta^2 E[\tilde{x}^2] \approx E[\tilde{x}^2] \end{aligned}$$

In full precision, $\mathrm{Var}(y) = 1$ with standard initialization methods, which is a major contributor to training stability. To preserve this stability after quantization, a LayerNorm is applied before the activation quantization (SubLN):

$$\mathrm{Var}(y) \approx E[\mathrm{LN}(\tilde{x})^2] = 1$$

Then the final formulation of BitLinear is:

$$y = \tilde{W}\tilde{x} = \tilde{W}\,\mathrm{Quant}(\mathrm{LN}(x)) \times \frac{\beta\gamma}{Q_b}, \qquad \mathrm{LN}(x) = \frac{x - E(x)}{\sqrt{\mathrm{Var}(x) + \epsilon}}$$

The factor $\frac{\beta\gamma}{Q_b}$ dequantizes the output back to the original precision.
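
Putting the pieces together, a BitLinear forward pass might look like the following. This is an illustrative sketch reusing the `binarize_weight` and `quantize_activation` helpers above, not the official implementation, and it omits the straight-through estimator discussed in Section 2.2:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BitLinear(nn.Module):
    """Sketch of y = W̃ x̃ · (βγ / Q_b) with x̃ = Quant(LN(x)) (SubLN inside)."""
    def __init__(self, in_features: int, out_features: int, b: int = 8):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.02)
        self.ln = nn.LayerNorm(in_features)
        self.b = b
        self.Qb = 2 ** (b - 1)

    def forward(self, x):
        w_bin, beta = binarize_weight(self.weight)              # 1-bit weights and scale β
        x_q, gamma = quantize_activation(self.ln(x), b=self.b)  # LN, then absmax quantization
        y = F.linear(x_q, w_bin)                                # y = W̃ x̃
        return y * (beta * gamma / self.Qb)                     # dequantize with βγ / Q_b
```

With the helpers above in scope, `BitLinear(512, 2048)(torch.randn(4, 512))` runs end-to-end.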

  4. Model parallelism with group quantization and normalization
  • The scaling parameters $\alpha, \beta, \gamma, \eta$ are calculated per group (one group per device partition), so they require no inter-device communication
  • If the number of groups is $G$, the per-group parameters become

$$\alpha_g = \frac{G}{nm} \sum_{ij} W_{ij}^{(g)}, \qquad \beta_g = \frac{G}{nm} \|W^{(g)}\|_1, \qquad \gamma_g = \|x^{(g)}\|_\infty, \qquad \eta_g = \min_{ij} x_{ij}^{(g)}$$

  • LayerNorm statistics are computed group-wise in the same way
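
A hedged sketch of these group-wise parameters, assuming the weight is split along the output dimension and the activation along the feature dimension into G shards (in real model parallelism each device would compute only its own entry):

```python
import torch

def group_scales(w: torch.Tensor, x: torch.Tensor, G: int):
    """Per-group scaling parameters; each group depends only on its own shard,
    so no cross-device communication is needed."""
    w_groups = w.chunk(G, dim=0)                   # split the weight into G groups
    x_groups = x.chunk(G, dim=-1)                  # split the activation likewise
    alpha = [wg.mean() for wg in w_groups]         # α_g = (G/nm) Σ_ij W_ij^(g)
    beta  = [wg.abs().mean() for wg in w_groups]   # β_g = (G/nm) ||W^(g)||_1
    gamma = [xg.abs().max() for xg in x_groups]    # γ_g = ||x^(g)||_∞
    eta   = [xg.min() for xg in x_groups]          # η_g = min_ij x_ij^(g)
    return alpha, beta, gamma, eta
```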

2.2 Model Training

Straight-through estimator

  • ignore non-differentiable functions (Clip, Sign) during backpropagation

Mixed Precision training

  • weights and activations → low precision
  • gradients and optimizer states → high precision
  • latent weights → high precision (binarized on the fly during the forward pass and never used at inference)
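
One common way to realize the straight-through estimator with high-precision latent weights in PyTorch is the `detach` trick sketched below; this is my own sketch of the idea, not the paper's code:

```python
import torch

def ste_binarize(w_latent: torch.Tensor) -> torch.Tensor:
    """Forward pass uses the 1-bit weight; the backward pass routes the
    gradient straight to the high-precision latent weight (STE)."""
    alpha = w_latent.mean()
    w_bin = torch.where(w_latent - alpha > 0,
                        torch.ones_like(w_latent), -torch.ones_like(w_latent))
    # detach() drops the non-differentiable Sign from the autograd graph;
    # adding w_latent back makes the gradient w.r.t. w_latent the identity.
    return (w_bin - w_latent).detach() + w_latent
```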

Large Learning Rate

  • A small update on the latent weights often makes no difference in the resulting 1-bit weights
  • This problem is even worse at the beginning of training
  • BitNet benefits from a large learning rate, while the FP16 Transformer diverges at the beginning of training under the same learning rate

2.3 Computational Efficiency

Arithmetic Operations

Energy Consumption during matrix multiplication

Multiplying an $m \times n$ matrix by an $n \times p$ matrix:

  • Vanilla Transformer

$$\begin{aligned} E_{add} &= m \times (n-1) \times p \times \hat{E}_{add} \\ E_{mul} &= m \times n \times p \times \hat{E}_{mul} \end{aligned}$$

  • BitNet → since the weights are 1-bit, the matrix multiplication is dominated by additions; multiplications remain only for scaling the output (with $\beta$ and $\gamma / Q_b$)

$$E_{mul} = (m \times p + m \times n) \times \hat{E}_{mul}$$
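
As a rough back-of-the-envelope sketch of these formulas (the per-operation energies below are illustrative placeholders, not the paper's 45 nm / 7 nm numbers, and the addition term for BitNet is assumed unchanged):

```python
# Energy of multiplying an (m×n) weight by an (n×p) activation, per the formulas above.
E_ADD, E_MUL = 0.03e-12, 0.2e-12   # assumed joules per add / multiply (placeholders)

def vanilla_energy(m, n, p):
    return m * (n - 1) * p * E_ADD + m * n * p * E_MUL

def bitnet_energy(m, n, p):
    # 1-bit weights: the multiply-accumulates reduce to additions; multiplications
    # remain only for the (m*p + m*n) scaling operations.
    return m * (n - 1) * p * E_ADD + (m * p + m * n) * E_MUL

m, n, p = 4096, 4096, 4096
print(vanilla_energy(m, n, p) / bitnet_energy(m, n, p))   # rough energy ratio
```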

3. Comparison with FP16 Transformers

3.1 Setup

  • Model sizes ranging from 125M to 30B parameters
  • SentencePiece tokenizer with a 16K vocabulary
  • Trained on the Pile, Common Crawl snapshots, RealNews, and CC-Stories datasets
  • An FP16 Transformer baseline trained with the same setup

3.2 Inference-Optimal Scaling Law

The loss scales as a power law with the amount of computation

  • The loss gap between the FP16 Transformer and BitNet shrinks as the model size grows
  • However, this power law doesn't properly model the relationship between the loss and the actual compute
    • counting FLOPs doesn't work well, since the weights are 1-bit
    • it mainly measures the training computation, not inference

Inference-Optimal Scaling Law

  • plots the loss against energy consumption
  • focuses on the inference energy cost rather than training compute
  • BitNet reaches significantly better loss at a much smaller inference energy cost

3.3 Downstream Tasks

0-shot and 4-shot results on downstream tasks across model sizes

3.4 Stability Test

Trained with varying peak learning rates: BitNet remains stable at learning rates where the FP16 Transformer diverges

4. Comparison with Post-training Quantization

4.1 Setup

  • Transformer (W16A16)
  • SmoothQuant (W4A4)
  • GPTQ (W2A16)
  • BitNet (W1A8)
  • QuIP (W2A16)

4.2 Result

  • Among the 4-bit models, weight-only quantization methods outperform the weight-and-activation quantizers, since activations are harder to quantize
  • In the lower-bit setting, BitNet achieves significantly better results than the baselines
  • BitNet has consistently superior scores over all baselines, showing the advantage of quantization-aware training over post-training quantization

5. Ablation studies

  • Quantization method: absmax vs. the elastic function
  • Normalization: SubLN vs. Pre-LN vs. BMT

6. Conclusion

  • a 1-bit Transformer architecture for large language models
  • scalable and stable to train
  • competitive perplexity and downstream task performance
  • reduced memory footprint and energy consumption
  • follows a scaling law similar to that of full-precision Transformers
