[핵심] [22.08]Optimal Brain Quantizer

YEOM JINSEOP·2024년 8월 5일
post-thumbnail

핵심 아이디어

  • quantization 했을 때 , 전체 Loss에 최소로 영향을 주는 weight를 quantization하고, 나머지 weight들을 update 한다. (greedy하게 각 row 별로)

  • 위 과정을 각 row의 weight가 모두 quantization 될 때 까지 반복한다.


핵심 수식

  1. 전체 Loss: WlXlWl^Xl22||\bold{W}_l\bold{X}_l - \hat{\bold{W}_l}\bold{X}_l||^2_2
    where, given a layer ll, weights Wl\bold{W}_l, layer inputs Xl\bold{X}_l, quantized weights Wl^\hat{\bold{W}_l}
    \\
    Hessian: H=2XXT\bold{H} = 2\bold{X}\bold{X}^T

  2. 각 row에서 quantization할 weight를 고른다.(quantization 시에 Loss에 최소 영향을 주는 weight 선택)

  3. 각 row에서 quantization되지 않은 나머지 weight를 update한다.

  4. 각 row에서 H1H^{-1}에서 qq-th row와 qq-th column을 제거(Hessian Update)한다.

  • 전체적인 작동 알고리즘

code 구현 시각화 (Visualization)

🟨 아래의 layer를 quantization한다고 가정하자.
Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) 
  • 위 Conv2d layer의 weight WW와 Hessian inverse H1H^{-1} 구하기

  • 여러 row를 parallel하게 (parallel = 32) 연산한다.

    • weight quantization 및 quantization error 나머지 weight update
    • H1H^{-1}, H1H^{-1} diagonal 계산 (이후 [H1]qq[H^{-1}]_{qq} 사용 위함)

  • Loss에 가장 영향을 적게 미치는 weight를 quantization할 wqw_q로 선택

 score = err/diag # score를 계산하고 (Loss에 미치는 영향 계산)
 j = torch.argmin(scores, 1) # 각 row에서 quantize할 column index 하나씩 결정. # (32,)  

  • quantization한 weight 이외의 나머지 weight들을 update
  • H1H^{-1} update H1H^{−1}에서 qq-th row와 qq-th column을 제거

  • 위 과정을 각 row의 모든 column들이 quantization 될 때까지 반복

한 parallel당 O(parallel rows(32)×columns)O(\text{parallel rows}(32) \times \text{columns}) 연산 소요.
코드를 보면, 모든 row에서 0을 갖는 column의 경우 연산에서 제외. (가중치가 0이면 quantization 영향이 없기 때문이다.)

0개의 댓글