Paper review[LLM-QAT: Data-Free Quantization Aware Training for Large Language Models]

이상민·2024년 12월 5일
0

논문리뷰

목록 보기
21/29

url : https://arxiv.org/pdf/2305.17888


배경)

QAT가 연구되지 않은 이유)

  • LLM 학습 기술적으로 힘듦 & 비용 비쌈
  • QAT는 학습 데이터가 필요 (LLM이 얻긴 힘듦 → pretraining의 규모 & 다양성 자체로 장애물)

의의)

  • QAT를 LLM에 처음 적용, 4-bit quantized LLM 구현
  • 모델 성능 유지하며, Weights, Activations와 KV cache를 동시에 quantization → 긴 시퀀스 생성 시, 처리 속도 bottleneck 현상 해결
  • Original training set의 큰 subset으로 학습한 것과 비교: original model의 output distribution을 더 잘 보존
  • 새로운 Data-free distillation 방법 소개: QAT가 Large pretrained generative model에 실용적으로 적용할 수 있게 만들어줌
  • 7B, 13B, 30B LLaMA모델을 4-bits quantized weights & 4-bits quantized KV cache를 가지고 distillation 가능
  • Activation을 6-bit precision까지 quantization

방법론)

2 측면을 신경써야 함:

  • Quantization 후에도 LLM의 zero-shot generalization 능력이 유지되어야 함 → 적절한 fine-tuning dataset을 선택하는 것이 중요
  • 기존 training set을 완벽히 복사하는 것이 힘듦 (이유: LLM 학습의 규모와 complexity 때문)

방법1) Data-free distillation

Pretrain data의 distribution을 얻기 위해 → 이미 pretrain data로 학습된 pretrained model을 사용해 next token data generation

<start token>을 랜덤하게 뽑은 후, 문장 생성


3가지 sampling strategies 실험)

  • 방법1 : < Next token의 top-1 candidate 뽑기
    • 가장 간단한 방법 >
    • 생성된 문장의 다양성 부족
    • 여러 토큰의 주기적인 반복이 일어남
  • 방법2 : < Pretrained model의 output을 softmax해서 확률로 보고, 확률적으로 sampling >
    • 가장 간단한 방법의 단점을 보완
    • 다양한 문장 생성
    • Student model의 accuracy 향상(결과)
  • 방법3 : < Hybrid Sampling strategy >
    • 배경) 예측 추세를 결정하는데, 처음 몇 개의 토큰이 중요한 역할을 함 → 처음 몇 개의 토큰에 대해서는 ‘정확’하고 ‘확신 있는’ 예측이 필요함
    • 방법) 처음 3~5개의 토큰은 top-1 prediction 사용 & 그 이후에는 확률적으로 sampling

방법2) Quantization-aware training

LLM weight&activation distribution에서는 outlier가 많이 존재 (smaller model이랑 다름) → small model의 SOTA quantization clipping method가 잘 작용하지 않을 수 있음 → LLM에 맞는 적절한 quantizer 식별

Quantization for LLMs

  • Quantization Function

    Outlier가 확실히 안좋은 영향을 주지만, 그렇다고 quantization을 할 때 outlier를 없애는 것도 LLM 성능에 안좋다

    학습의 초기 단계 동안, clipping 기반 방법은 perplexity score를 매우 높힌다(i.e., > 10000) → fine-tuning을 통해 회복하기 힘들 정도로 상당한 정보 손실 유발
    \therefore 이상치를 유지하는 방법 선택

    실험적 Finding) GLU function을 사용하는 모델 → Activations & Weight : Symmetric distribution

    LLaMA model의 activation function: SwiGLU = Swi + GLU

    \therefore Symmetric MinMax quantization 선택 ⇒ Activation & Weight 모두에

    XQi=αXRiα,α=max(XR)2N11\text{X}_\text{Q}^i = \alpha \lfloor \frac{\text{X}_\text{R}^i}{\alpha} \rceil, \quad \alpha = \frac{\max(\text{X}_\text{R})}{2^{N-1}-1}

    XQ\text{X}_\text{Q} : quantized activations or weights, XR\text{X}_\text{R} : real-valued weights or activations

    Efficient quantization을 위해, per-token activation quantization & per-channel weight quantization 채택 (아래 그림)

  • Quantization-aware training for key-value cache

    • Key, Value에도 per-token quantization 사용 (위 그림 (b))
    • 생성 process 동안 현재 key, value가 quantizing 되고 해당 scaling factor는 저장됨
    • QAT 학습 process 동안, key & value의 전체 activation tensors에 quantization 적용 (Figure 2 확인)
  • Knowledge Distillation

    • Cross-Entropy 기반 logits distillation 사용 (teacher model: full-precision & student model: quantized)

      LCE=1nci=1npcT(Xi)log(pcS(Xi))\mathcal{L}_{\text{CE}} = -\frac{1}{n} \sum_c \sum_{i=1}^n p_c^\mathcal{T}(X_i)\log(p_c^\mathcal{S}(X_i))

      ii : 전체 n개의 문장에서 i번째 sample, cc : class의 수(논문의 경우 vocabulary size)

      T,S\mathcal{T,S} : Teacher network, Student network

Summary)

결국 top-1 candidate를 뽑지 않고 distribution에 맞춰 sampling 하기 때문에, student model을 학습시킬 때 사용하는 sampling에 내재된 noise가 들어있을 것 → 학습에 사용하는 next token은 optimal label을 사용하는 것은 아니다

이를 보완하기 위해 teacher model의 soft label(logit)을 활용하여, student model을 학습시 더 풍부한 정보 제공!

Experiments)

논문에서 LLM의 일반화 성능 유지를 중요시 하는 것 같음. 따라서 zero-shot performance와 few-shot performance를 확인

Setting]

Data Generation: LLaMA-7B model 활용 & 생성 문장의 최대 길이 1024

Main Result]

Bits: Weight - Activation - KV

Main result에 대한 느낀점 및 의문점)

낮은 bit로 quantization 했을 때, LLM-QAT가 좋은 성능을 내는 것처럼 보이나 비교군이 적절하지 않은 것 같음

그렇게 생각한 이유) SmoothQuant의 경우 핵심이 W8A8(weight: 8-bit, activation: 8-bit)을 이끌어 내는 것이며, KV cache의 quantization에 대한 내용은 해당 논문에 나오지 않음

→ 그렇기에 SmoothQuant의 weight와 activation이 8-bit인 결과만 비교할 거리가 된다고 생각 & KV cache의 quantization에 대한 내용은 smoothquant에 나오지 않은데, KV의 bit에 따른 비교가 적절한지 모르겠음

따라서 SmoothQuant와 비교하려면 8-8-16을 비교하는 것이 맞다고 보는데, 이는 SmoothQuant가 전반적으로 더 성능이 좋아보임(zero-shot performance의 경우)

LLaMA all 16-bit인 baseline과 비교하는 것은 유의미하다고 보는데, weight, KV cache를 4-bit까지 & activation을 6-bit까지 줄여도 성능이 어느정도 유지된다는 점은 유의미하다고 봄

Data Generation을 LLaMA 7B로 했는데, 그러면 teacher model이 LLaMA 7B인 것일까? 그러면 LLaMA 13B, LLaMA-30B 모델의 결과는 규모가 더 작은 모델이 teacher model이 된 것일까?

Ablation]

  • Data Choice
    Generated 데이터 셋의 규모가 궁금했으나, 확인하지 못했음 (Introduction에서 100k sampling으로 distillation을 성공했다고 언급했으므로, 그정도 규모로 distillation 했을 것으로 추측)
  • Quantization Function
  • Knowledge Distillation

Conclusion & Limitations

  • Data-free quantization-aware training 제안 → QAT를 활용하여 4-bit quantization을 가능하게 함
  • Training-data-agnostic distillation 방법
  • 4-bit quantization는 hardware support가 없음 → hardware implementation을 넣지 못했음
  • 4-bit weight, 4-bit KV cache, 8-bit activation 까지는 작동을 잘 했지만, 4-bit activation quantization에서는 충분하지 않았음

느낀점)

  • 4-bit quantization이 흔하지 않다보니, 논문에서 이를 비교한 부분이 아쉬웠다(SmoothQuant는 4-bit quantization의 비교군으로 적절하지 않다고 봄)

  • 잘 학습된 모델에서 생성하는 data를 이용해 distillation을 하며 quantized model이 일반화 성능을 잃지 않는다는 점이 흥미로웠음

profile
수학, AI, CS study 그리고 일상🤗

0개의 댓글