[Paper Review] AWQ: ACTIVATION-AWARE WEIGHT QUANTIZATION FOR ON-DEVICE LLM COMPRESSION AND ACCELERATION

강현구·2024년 6월 24일

Compression

목록 보기
3/3
post-thumbnail

AWQ: ACTIVATION-AWARE WEIGHT QUANTIZATION FOR ON-DEVICE LLM COMPRESSION AND ACCELERATION

1. Background

최근 연구 동향

  • LLM의 Edge Device로의 배포
    • 클라우드 서버로 보내지 않아도 됨. -> Latency 이득
    • 오프라인 가능 -> 실시간 처리 용이
    • 서버가 아닌 로컬에 데이터 저장 -> 데이터 보안 강화
  • Quantization
    • 모델의 bit precision을 줄여 모델 크기를 줄여 추론 속도 이득을 보는 것
    • QAT, PTQ 가 많이 사용되는 추세
    • QAT는 훈련 비용이 많이 들어 비효율적
    • PTQ는 4bit 이하의 low bit에서 정확도 저하가 큼
  • LLM Quantization
    • W8A8 -> W와 A모두 INT8로 양자화
    • W만 양자화 -> GPTQ

LLM Quantization의 한계

  • 너무 큰 메모리
  • 양자화의 효율성 문제
    • QAT : 훈련시 고비용
    • PTQ : 저비트에서 큰 오류
      • GPTQ : Calibration Dataset에 대한 Overfitting
  • 성능 저하 - 중요 가중치가 고르게 분포하지 못함

AWQ 제안

  • backpropagation / reconstruction 에 의존 X (일반화 성능 보장)
  • 단일 정밀도 양자화 방법 제안(W만 양자화) (하드웨어 효율성)
  • 중요 가중치 보호, 가중치 스케일링 등 (양자화 오류 감소)

2. Method

Suggested Algorithm : AWQ

  1. Calibration Dataset을 넣었을 때 각 layer에서의 activation 분석
  2. activation 값이 큰 channel 선택 (0.1~1%)
  3. 식별된 중요한 channel에 대해 Scaling 수행(scaling factor는 초기 값 사용)
  4. scaling factor 최적화 (scaling factor를 greedy 하게 대입하며 quantization loss를 최소화하는 scaling factor 탐색)
  5. 최적화한 scaling factor를 이용하여 quantization 수행

1. Preserve 1% Salient Weights

  • LLM의 Weight들은 동등하게 중요하지 않음.
  • 특정 중요한 Weight를 판단하고, 그 Weight는 양자화하지 않는 전략을 취함.
  • Weight의 중요성은 대개 크기나 L2 Norm 값으로 판단함.
  • 그러나 크기에 따라 Weight 중요성을 판단하여 보존하는 전략은 효과가 없었음.
  • 대신 Activation의 크기에 따라 Weight 중요성을 판단하여 보존하면 큰 효과가 있었음.
  • Weight를 보존하게 되면 Weight에 원본(FP16)과 quantized weight(int3) 이 공존하게 되어, 구현이 매우 어려워짐.
  • 따라서, 중요한 Weight를 Quantize하며 보존하는 방법을 찾기로 함.

2. Protecting Salient Weights by Activation-aware Scaling

Group-Wise Channel Scaling

  • Channel-Wise Scaling 사용
  • Delta는 Scaler를 의미. 좌측 함수로 양자화를 수행(Scaled Round Function)
  • w는 Group 내 가중치를 의미함
  • 초기 스케일러는 그룹 내 가중치 절대값의 최대값을 이용하여 정의.

Update Scaling Factor

  • s는 scaling factor, WX는 calibration data가 원본 가중치W에 들어갔을 때 활성화되는 Weight, Q()는 calibration data가 양자화된 가중치에 들어갔을 때 활성화되는 Weight.
  • 즉 L(s) = |W'X - WX| 를 의미하고, 결국 양자화된 값과 원본 값의 차이를 의미함.
  • 위 수식을 만족하도록 하는 scaling factor s를 grid search로 탐색

3. Result

Performance

  • Llama2, LLaMA 에서 RTN, GPTQ로 Quantize한 것보다 낮은 perplexity 도출.

Generalization

  • 멀티모달 태스크에서 틀리는 경우가 현저히 감소함.

4. Conclusion

Summary

  • 가중치 채널을 보호하고, 양자화 오류를 최소화하여, 다양한 모델에서 높은 성능을 보임
  • 특정 데이터셋에 과적합되지 않도록 설계되어, 여러 도메인에서 일관된 성능을 보임

Limitation

  • Weight만 양자화하고, Activation은 양자화하지 않았음.
profile
고려대학교 인공지능학과 SLP Lab 석사과정생

0개의 댓글