Quantization - 8bit quantization

손기훈·2024년 9월 2일

8-bit quantiztion 요약

A Gentle Introduction to 8-bit Matrix Multiplication for transformers at scale using transformers, accelerate and bitsandbytes

quantization배경

  • LLM 등의 대형 언어 모델의 파라미터가 거듭해서 커짐에 따라, 자원을 아끼고 성능은 유지한 채 배포를 하고자 하는 노력이 계속해서 이어지고 있다. 양자화는 이에 대한 노력의 결과라고 할 수 있을 것이다.
  • 양자화의 아이디어는 mixed-precision으로 추론시에 16bit의 데이터 타입을 그대로 써도 성능이 하락하지 않았던 것에서부터 출발한다. 자세한 사항은 이전 포스트인 Mixed Precision Training에서 자세하게 다루고 있다.

학습에 자주 쓰이는 데이터 타입

  • 딥러닝 학습에서 주로 사용되는 데이터 타입은 IEEE-754의 FP32(float, 32bit, 4byte)의 데이터 타입이다. 이때, 총 32bit의 데이터 중 1bit는 부호 비트, 8bit는 지수부, 나머지 23bit는 가수(mantisa)부로 쓰이게 된다. 이는 4byte의 크기 이므로 크기가 절반으로 줄인 데이터(2byte)를 사용하고자 하는 시도가 이루어졌다.
  • mixed-precision 기법이 나오기 이전의 16bit로 실수로 표현하는 과정이 실패했던 까닭은 16bit의 부동소수점의 경우 16비트 중 1bit는 부호 비트, 5bit는 지수부, 10bit가 가수부에 사용된다. 따라서, 지수부에서 3bit 즉 8배 적은 크기의 범위의 데이터만을 다루게 된다. 이때, overflow&underflow가 발생하게 되고 이는 nan 값으로 나타나는 문제가 생기게 되었다.
  • 이를 해결하기 위해 지수부를 8bit로 하는 실수 데이터 타입(BF16)을 사용하였으나 이 경우, 유효숫자를 표기하는 가수부의 bit수가 7bit로 줄어들게 되어 정밀도에서 손해를 보게 되는 경우가 발생하였다. mixed-precision에서는 이 문제를 해결하기 위해 32bit의 가중치는 따로 저장하여 참조하여 사용하고, 순전파/역전파 시에는 16bit를 사용하여 학습 속도를 높이게 되었다.

Mixed Precision

  • mixed-precision의 경우 따로 자세히 다루게 되겠지만, 간단히 설명하자면 FP32 정밀도를 가진 파라미터는 따로 저장해두고 FP16의 정밀도를 순전파/역전파 시에 사용하며, 중요한 연산시에는 다시 FP32의 정밀도를 사용하여 계산하는 과정이다.
  • 다만, 추론 시에도 FP16의 정밀도를 사용하여도 성능의 하락이 보이지 않았다. 이는 사실 FP32의 정밀도가 가중치의 업데이트에만 사용되기 때문이다.

Quantization

위의 과정을 통해 16bit에서 더 나아가 더 적은 데이터 타입으로 가중치를 저장하면서 성능의 하락은 막을 수 있는 방법 대한 연구가 이루어지게 되었다. 이에 대한 방안으로 8bit quantization(이하 8비트 양자화 / 양자화)이 세상에 나오게 되었다.

다만, 들어가기에 앞서, 양자화의 타입 캐스팅 과정은 항상 ‘반올림’의 과정을 동반하게 된다. 예를 들어 0~9 범위의 데이터 타입을 0~4의 범위를 가지는 데이터 타입으로 캐스팅한다고 가정해보자. ‘4’를 0~4로 캐스팅하게 될 경우, 그의 절반인 ‘2’가 될 것이다. 이때, ‘3’을 캐스팅하게 될 경우 1.5 가 될 것이다. 하지만 해당 데이터 타입은 1.5가 없으므로 반올림을 하여 똑같이 ‘2’가 될 것이다. 이 때 데이터는 0.5라는 오차를 갖게 되는 것이다.

양자화는 위와 같은 오차를 줄이면서, 실수 데이터를 정수(int8, char, unsigned char, 8bit, 1byte) 데이터와 매핑 시키는 작업이다. 이제 8비트 양자화에서 주로 사용되는 양자화 기법들을 알아보자.

두 방법들의 공통점이라면 실수를 정수로 매핑하는 과정에서 상수를 사용하여 스케일링(scaling)하는 과정이 필수적으로 동반된다는 점이다.

zero point quantization

  1. 실수형 데이터의 최소값과 최대값을 구한다. rmin,rmaxr_{min}, r_{max}

  2. scale 상수값을 구한다. 이 때, qmax,qminq_{max}, q_{min}의 값은 캐스팅하고자 하는 데이터의 타입의 최대값과 최소값의 범위이다. int8의 경우 -128 ~ 127이 될 것이다.

    scale=qmaxqminrmaxrmin\text{scale} = \frac{q_{max} - q_{min}}{r_{max}-r_{min}}

  3. zero point 계산

    1. zero point는 실수값이 0일 때의 정수값을 의미한다.

      zero point=round(rminscale)qmin\text{zero point} = -round(r_{min}*scale) - q_{min}

    2. 이때 반올림하여 정수값을 표현한다.

  4. 양자화 변환
    quantized vale=round(rscale+zero point)\text{quantized vale} = round(r *{scale} + \text{zero point})

  5. 역양자화
    r=quantized value - zero pointscale\text{r} = \frac{\text{quantized value - \text{zero point}}}{\text{scale}}

  • 예시 실수의 범위 [-1.0 ~ 1.0], 값은 0.5이고 int8로 양자화 한다고 가정할 때
    1. scale 계산
      Scale=qmaxqminrmaxrmin=127(128)1.0(1.0)=127.5\text{Scale} = \frac{q_{max} - q_{min}}{r_{max} - r_{min}} = \frac{127 - (-128)}{1.0 - (-1.0)} = 127.5
    2. 제로포인트계산
      zero point=round(rminscale)qmin=round(1.0127.5)(128)0.0\text{zero point} = -round(r_{min}*scale) - q_{min} = -round(-1.0 * 127.5) - (-128) \approx 0.0
    3. 양자화
      quantized vale=round(rscale+zero point)=round(0.5127.5+0.0)=64\text{quantized vale} = round(r *{scale} + \text{zero point}) = round(0.5 * 127.5 + 0.0) =64
    4. 역양자화
      r=quantized value - zero pointscale=640.0127.5=0.50196\text{r} = \frac{\text{quantized value - \text{zero point}}}{\text{scale}} = \frac{64 - 0.0}{127.5} = 0.50196
  • 코드
    import numpy as np
    
    def get_scale(r, d_type='int8'):
      r_min, r_max = np.min(r), np.max(r)
      d_min = np.iinfo(d_type).min
      d_max = np.iinfo(d_type).max
    
      scale = (d_max - d_min) / (r_max - r_min)
    
      return scale
    
    def get_zeropoint(r, d_type='int8'):
      r_min, r_max = np.min(r), np.max(r)
      d_min = np.iinfo(d_type).min
      d_max = np.iinfo(d_type).max
    
      scale = get_scale(r, d_type)
    
      zero_point = -1 * np.round(r_min * scale) + d_min
    
      return zero_point
    
    def zero_point_quantiztion(r, value,d_type='int8'):
    
      # get scale
      scale = get_scale(r, d_type)
    
      # get zero point
      zero_point = get_zeropoint(r, d_type)
    
      # quantize the value
      q_value = np.round(scale*value + zero_point)
      return q_value.astype(d_type)
    
    def dequantize_value(r, value, d_type='int8'):
      # get scale
      scale = get_scale(r, d_type)
    
      # get zero point
      zero_point = get_zeropoint(r, d_type  )
    
      # dequantize the value
      r_value = (value - zero_point) / scale
      
      return r_value
    
    r = np.linspace(-1.0 , 1.0)
    quantized_value = zero_point_quantiztion(r, 0.5)
    
    print(r)
    print(np.min(r), np.max(r))
    print('scale :', get_scale(r))
    print('zero point :', get_zeropoint(r))
    print(quantized_value)
    print(dequantize_value(r, quantized_value))
    
    '''
    scale : 127.5
    zero point : 0.0
    64
    0.5019607843137255
    '''

absmax quantization

  1. 주어진 텐서의 최대절대값을 구한다.
  2. 캐스팅하려는 범위 값을 최대 절대 값으로 나눠 scaling factor를 구한다.
  3. 값에 scaling factor를 곱한다.
  • 예시 텐서 : [1.2, -0.5, -4.3, 1.2, -3.1, 0.8, 2.4, 5.4] 가 있다고 할 때, int8로 변환한다고 해보자.
    1. 주어진 절대 최대값을 구한다.
      값 : 5.4
    2. 캐스팅하려는 범위 값을 최대 절대 값으로 나눠 scaling factor를 구한다.
      int8 = [-128, 127] = 127,
      scaling_factor = 127 / 5.4 = 23.5
    3. 각 원소에 곱해주고 반올림을 진행한다.
      [28, -12, -101, 28, -73, 19, 56, 127]
  • 코드
    def get_absmax(r):
      return np.max(np.abs(r))
    
    def get_range_data_type(d_type='int8'):
      return np.iinfo(d_type).max
    
    def get_absmax_scale(r, dtype='int8'):
      absmax = get_absmax(r)
      scale = get_range_data_type(dtype) / absmax
      return scale
    
    def absmax_quantization(r, value, dtype='int8'):
      absmax = get_absmax(r)
      scale = get_absmax_scale(r, dtype)
      quantized_value = np.round(value * scale)
      return quantized_value.astype(dtype)
    
    def dequantize_value(r, value, dtype='int8'):
      absmax = get_absmax(r)
      scale = get_absmax_scale(r, dtype)
      r_value = value / scale
      return r_value
    
    def matrix_absmax_quantization(matrix_q, dtype='int8'):
      abs_max = get_absmax(matrix_q)
      scale = get_range_data_type(dtype) / abs_max
      quantized_matrix = np.round(matrix_q * scale)
      return quantized_matrix.astype('int8')
    
    def matrix_dequantize_value(matrix_q, quantized_matrix, dtype='int8'):
      abs_max = get_absmax(matrix_q)
      scale = get_absmax_scale(matrix_q, dtype)
      r_value_matrix = quantized_matrix / scale
      return r_value_matrix
    
    vector_a = np.random.randn(8)
    value = vector_a[3]
    quantized = absmax_quantization(vector_a, value)
    print('vector a :', vector_a)
    print('4번째 원소 :', value)
    print('vector a의 절대 최대 값 :', get_absmax(vector_a))
    print('data의 범위값 :', get_range_data_type())
    print('양자화 값 :', absmax_quantization(vector_a, value))
    print('역 양자화 값', dequantize_value(vector_a,quantized))
    
    '''
    4번째 원소 : -0.19557766858400116
    vector a의 절대 최대 값 : 1.4651296870824921
    data의 범위값 : 127
    양자화 값 : -17
    역 양자화 값 -0.196119721892932
    '''

실제로는 두 방법을 적절히 혼합하여 사용가능하다. 해당 사용에 대한 예시와 방법은 llm.int8 단락에서 간단하게 다루도록 한다.

matirx multiplication with quantization

zeropoint 양자화를 연산시에 사용하려면, Xi8X_{i8}의 모든 원소에다가 zero point를 더하는 연산을 필요로 하고, 이걸 행렬 곱 연산에 적용한다고 때 아래와 같은 경우가 발생하게 된다.

AB=C=(A+zeropointA)(B+zeropointB)A*B = C = (A + zeropoint_{A})(B + zeropoint_{B})
=AB+AzeropointB+BzeropointA+zeropointAzeropointB= AB + A*zeropoint_{B} + B*zeropoint_{A} + zeropoint_{A}*zeropoint_{B}

양자화의 정확도를 위해 zeropoint는 Int16의 정밀도로 계산이 되는데, 이 때 위의 연산을 수행하는 과정에서, 양자화된 AB의 연산을 제외하면 나머지는 16/32비트의 정밀도로 이루어지게 된다. 또한, 추가적인 연산이 발생하게 되므로 연산이 느려지게 되는 결과가 발생하게 된다.


양자화 된 Matrix의 행렬곱 연산의 수식은 다음과 같다

Xf16Wf16=Cf161cxf16cwf16Ci32=Sf16Ci32X_{f16}W_{f16} = C_{f16} ≈ \frac{1}{c_{x_{f16}}c_{w_{f16}}}C_{i32}= S_{f16} · C_{i32}
Sf16Ai8Bi8=Sf16Q(Af16)Q(Bf16)≈ S_{f16} · A_{i8}B_{i8} = S_{f16} · Q(A_{f16}) Q(B_{f16})

다만, 실제로 Xf16Wf16X_{f16}W_{f16}의 결과와 Sf16Ai8Bi8S_{f16} \cdot A_{i8}B_{i8}의 결과는 상당히 큰 오차를 보였다. 하지만, 중간에 Ai8Bi8A_{i8}B_{i8} 의 결과를 int32로 저장하고 다시 Sf16S_{f16} 으로 연산해준 결과는 기존의 양자화 이후의 결과와 큰 오차가 나지 않음을 확인할 수 있었다.

LLM.int8 에 대한 소개

quantization에 있어서 가장 큰 걸림돌은 텐서 당 하나의 스케일링 팩터 상수를 사용한다는 점이었다. 하나의 아웃라이어만 있어도 quantization의 정확도를 낮추기 때문에 문제가 발생했다. quantization의 수식이 min-max scaling과 매우 유사한 점을 보면 이는 쉽게 이해가 가는 문제이다. 따라서, 해당 문제를 해결하기 위해, 텐서당 여러 개의 scaling factor 상수를 사용하여 사용하는 것과 상수의 영향 범위를 제한하는 block-wise quantization 등의 기법이 제안되었다.
논문에서는 기존에 널리 사용되던 row-wise quantization을 개량한 vector-wise quantization과 mixed-precision decomposition의 방안을 제시한다. 실제 llm.int8 에서는 absmax vector-wise quantization과 mixed-precistion decomposition 방식을 혼합하여 사용한다고 한다.

vector-wise quantization

해당 방안은 행렬곱 연산이 여러 번의 내적으로 이루어져있다는 데에서 출발하는 아이디어이다. 입력 행렬 Xf16Rb×hX_{f16} \in R^{b \times h} 라고 할 때 Xf16X_{f16}의 row들 에서 각각 scaling factor를 추출하여 cxf16c_{x_{f16}}을 만들어내고 가중치 행렬 Wf16W_{f16}의 각 column에서 scalingfactor 벡터 $c{w{f16}}$을 추출하여 X와 W의 각 row와 column들의 내적 연산이 수행될 때 각 연산 별로 quantization을 수행한다는 것이다. 그리고 해당 연산이 끝난 후에 dequantization을 수행할 때는 각 $c{x{f16}} \otimes c{w_{f16}}$의 값으로 denormalization이 수행된다는 점이다. 정확한 수식은 아래와 같다.

Cf161cxf16cwf16Ci32=SCi32=SAi8Bi8=SQ(Af16)Q(Bf16)C_{f16} ≈ \frac{1} {c_{x_{f16}} ⊗ c_{w_{f16}}} C_{i32} = S · C_{i32} = S · A_{i8}B_{i8} = S · Q(A_{f16}) Q(B_{f16})

위의 과정을 순서대로 설명하자면 아래와 같다.

  1. 입력 X의 행을 따라서 scaling factor cxc_x를 추출한다.
  2. 가중치 행렬 W의 칼럼을 따라서 scaling factor cwc_w를 추출한다.
  3. X와 W의 행렬곱시 발생하는 각 행 * 열의 곱 마다 퀀타이제이션을 적용한다.
  4. 퀀타이제이션이 적용된 벡터를 다시 해제할 때는 행렬의 내적 결과값에 scaling factor 벡터 cxc_xcwc_w의 외적의 결과를 denormalization하여 해제한다.

mixed-precision decomposition

billion 단위의 파라미터를 가진 모델의 문제점은 성능에 중요한 어마어마한 양의 feature 들이 있고 해당 feature들이 높은 정밀도의 quantization을 요한다는 점이다. 이 점에서 vector-wise quantization은 outlier들이 있는 경우 효과적이지 않다.

논문에 따르면, 이 아웃라이어들이 sparse하고 입력 시퀀스의 차원 대부분에서는 발생하지만 hidden state의 차원내에서는 제한적으로 발생한다고 한다. 따라서, 아웃라이어를 해결하기 위한 새로운 decomposition을 제안해낼 수 있었다고 한다.

아래의 수식은 해당 decompostion에 대한 수식이다.

Cf16hOXhf16Whf16+Sf16hOXhi8Whi8C_{f16} \approx \sum_{h \in O} X_{h_{f16}} W_{h_{f16}} + S_{f16} \cdot \sum_{h \nsubseteq O} X_{h_{i8}} W_{h_{i8}}
  • O : O의 항은 아웃라이어가 발생한 feature dimension의 집합이다.
  • h : hidden dimension의 모든 feature 들이다.

위의 식을 살펴보면 hidden dimension에서 아웃라이어가 발생한 항은 모두 왼쪽 항에서 FP16의 정밀도로 계산을 해주고, 아웃라이어가 발생하지 않은 항은 quantization을 수행한 후에 FP16의 정밀도로 계산해줌을 알 수 있다.

위의 수식을 순서대로 설명하자면 아래와 같다.

  1. 입력 시퀀스 X에서 아웃라이어가 있는 row를 탐지. W에서 해당 row와 같은 인덱스의 칼럼을 골라줌 → 해당 행과 렬을 추출
  2. Outlier X, Outlier W를 FP16 정밀도로 행렬곱 연산 수행
  3. Non Outlier X, Non Outlier W의 행렬곱을 quantization 수행 후 FP16 정밀도로 복원
  4. 두 개의 행렬곱 연산 결과를 더함.

아래는 위의 과정에 대한 논문의 이미지이다.

이하는 Vectorwise Quantization과 Mixed-Precision decomposition의 과정을 numpy만을 사용하여 구현한 예시이다. Vectorwise-quantization에서 Int8 행렬끼리의 연산시 계속해서 오버플로우가 나는 문제가 있었고, 이는 내적 연산시에 먼저 타입 캐스팅하는 것으로 문제를 해결 할 수 있었다.

  • 코드 - vectorwise-quantization
    import numpy as np
    import math
    
    ## vector-wise quantization
    class VectorWiseQuantization:
      def __init__(self, X, W):
        self.X = X
        self.W = W
    
        # scaling factor vectors
        self.C_x = self.get_abs_max(self.X, axis=1) # by row
        self.C_w = self.get_abs_max(self.W, axis=0) # by column
    
        #quantized_x, quantized_w
        self.q_x = self.absmax_quantization_x()
        self.q_w = self.absmax_quantization_w()
    
        #quantized_matrix
        self.quantized_matrix_multiplication = self.quantized_matrix_multiplication(self.q_x, self.q_w)
    
        #dequantized_matrix
        self.dequantized_matrix = self.dequantization()
    
      def get_abs_max(self, matrix, axis=0):
        return np.max(np.abs(matrix), axis=axis)
    
      def get_range_data_type(self, d_type='int8'):
        return float(np.iinfo(d_type).max)
    
      def get_absmax_scale(self, absmax, dtype='int8'):
        scale = self.get_range_data_type(dtype) / absmax
        return scale.astype('float16')
    
      def absmax_quantization_x(self):
        scale = self.get_absmax_scale(self.C_x)
        quantized_x = np.round(self.X * scale[:, np.newaxis] )
        return quantized_x.astype('int8')
    
      def absmax_quantization_w(self):
        scale = self.get_absmax_scale(self.C_w)
        quantized_w = np.round(self.W * scale[np.newaxis, :])
        return quantized_w.astype('int8')
    
      def quantized_matrix_multiplication(self, x, w):
        ## if we do not type-cast before the matmul overflow issue will come out.
        x_32 = x.astype('int32')
        w_32 = w.astype('int32')
        result = np.dot(x_32, w_32)
        return result
    
      def dequantization(self):
        outer_product = np.outer(self.C_x, self.C_w)
        matrix_ = self.quantized_matrix_multiplication * outer_product
        matrix_ = matrix_ / (self.get_range_data_type() ** 2)
        
        return matrix_
    
    np.random.seed(0)
    a = np.random.random((5,5))
    print('X :\n', a)
    
    b = np.random.random((5,5))
    print('W :\n', b)
    
    c = VectorWiseQuantization(a, b)
    print("C_x :", c.C_x)
    print("C_w :", c.C_w)
    
    print("Quantized_X :\n", c.q_x)
    
    print("Quantized_W :\n", c.q_w)
    
    print("Quantized_multiplicated_matrix :\n", c.quantized_matrix_multiplication)
    
    print('outer product of c_x, c_w :\n', np.outer(c.C_x, c.C_w))
    
    print('=' * 50)
    print(c.dequantized_matrix)
    print(np.dot(a, b))
    
    '''
    결과 :
    
    X :
     [[0.5488135  0.71518937 0.60276338 0.54488318 0.4236548 ]
     [0.64589411 0.43758721 0.891773   0.96366276 0.38344152]
     [0.79172504 0.52889492 0.56804456 0.92559664 0.07103606]
     [0.0871293  0.0202184  0.83261985 0.77815675 0.87001215]
     [0.97861834 0.79915856 0.46147936 0.78052918 0.11827443]]
    W :
     [[0.63992102 0.14335329 0.94466892 0.52184832 0.41466194]
     [0.26455561 0.77423369 0.45615033 0.56843395 0.0187898 ]
     [0.6176355  0.61209572 0.616934   0.94374808 0.6818203 ]
     [0.3595079  0.43703195 0.6976312  0.06022547 0.66676672]
     [0.67063787 0.21038256 0.1289263  0.31542835 0.36371077]]
    C_x : [0.71518937 0.96366276 0.92559664 0.87001215 0.97861834]
    C_w : [0.67063787 0.77423369 0.94466892 0.94374808 0.6818203 ]
    Quantized_X :
     [[ 97 127 107  97  75]
     [ 85  58 117 127  51]
     [109  73  78 127  10]
     [ 13   3 122 114 127]
     [127 104  60 101  15]]
    Quantized_W :
     [[121  24 127  70  77]
     [ 50 127  61  77   3]
     [117 100  83 127 127]
     [ 68  72  94   8 124]
     [127  35  17  42  68]]
    Quantized_multiplicated_matrix :
     [[46727 38766 39340 34084 38567]
     [41987 32035 36849 28433 40794]
     [35871 29181 36878 24593 34946]
     [39878 25546 24835 22881 39276]
     [36360 30053 37202 25956 31255]]
    outer product of c_x, c_w :
     [[0.47963307 0.5537237  0.67561716 0.67495859 0.48763063]
     [0.64626874 0.74610017 0.91034226 0.90945488 0.65704483]
     [0.62074016 0.7166281  0.87438237 0.87353005 0.63109058]
     [0.58346309 0.67359272 0.82187343 0.82107229 0.59319194]
     [0.65629852 0.75767929 0.92447033 0.92356918 0.66724185]]
    ==================================================
    [[1.38953528 1.33087315 1.64788761 1.42633075 1.16600226]
     [1.68236627 1.48188475 2.07980667 1.6032321  1.66181951]
     [1.38053011 1.29654192 1.99922334 1.33193158 1.3673564 ]
     [1.44257804 1.0668733  1.26549859 1.16479355 1.44449171]
     [1.47950984 1.41177604 2.13231727 1.48627699 1.29299052]]
    [[1.39270148 1.32860775 1.6512939  1.42824068 1.16938443]
     [1.68347281 1.47905707 2.08164363 1.60639261 1.66607888]
     [1.37780687 1.29014235 1.99450313 1.32804445 1.36853399]
     [1.43857693 1.06090183 1.26023633 1.16403584 1.43948803]
     [1.48261208 1.40749225 2.13347927 1.4847943  1.29890636]]
    '''
  • 코드 - mixed-precision decomposition
    class MixedPrecisionDecomposition:
      def __init__(self, X, W):
        self.X = X
        self.W = W
    
        #outlier matrices
        self.outlier_indices = self.get_outlier_indices(self.X, axis=0)
        self.outlier_x = self.X[:, self.outlier_indices]
        self.outlier_w = self.W[self.outlier_indices, :]
    
        #reuslt of matmul as float16 with outlier matrices
        self.outlier_result = self.dot_outliers()
    
        #non-outlier matrices
        self.non_outlier_indices = np.setdiff1d(np.arange(self.X.shape[0]), self.outlier_indices)
        self.non_outlier_x = self.X[:, self.non_outlier_indices]
        self.non_outlier_w = self.W[self.non_outlier_indices, :]
    
        #quantization matmul matrix
        self.quantized_matrix_multiplication = VectorWiseQuantization(self.non_outlier_x, self.non_outlier_w)
    
        self.dequantized_matrix = self.quantized_matrix_multiplication.dequantized_matrix
    
        #dequantization matrix + f16 outlier matmul matrix
        self.mixed_precision_matrix = self.combine()
    
      def get_outlier_indices(self, matrix, axis=0, threshold=6):
        
        outlier_indices = np.any(np.abs(matrix) > threshold, axis=axis)
        return np.where(outlier_indices)[0]
    
      def dot_outliers(self):
        x_f16 = self.outlier_x.astype('float16')
        w_f16 = self.outlier_w.astype('float16')
        result = np.dot(x_f16, w_f16)
        return result
    
      def combine(self):
        return self.dequantized_matrix + self.outlier_result
        
        
        
     '''
     print(decompostion.mixed_precision_matrix)
     
     [[-16.55756372 -12.26151912  -5.64104254   8.66123818   1.85708373]
     [  4.41882319  -0.71268704  -1.04309271 -12.55671284  -5.27044418]
     [ -0.56348626   2.26975905  -0.82476005 -10.3013141   -2.03545758]
     [ -0.71689122   2.87811143  -8.49162827  -1.14054711   0.15946122]
     [ 11.20172562  -1.7876153   11.89701298 -31.97225726 -13.03379419]]
     
     print(np.dot(random_matrix_x, random_matrix_w))
     
     [[-16.55512064 -12.25679827  -5.64060626   8.66883843   1.85265038]
     [  4.41967919  -0.71840279  -1.04153266 -12.5558374   -5.26652566]
     [ -0.5615892    2.29197021  -0.79650171 -10.30077733  -2.04243318]
     [ -0.71990321   2.85994283  -8.501364    -1.13821841   0.16393162]
     [ 11.19986743  -1.78736878  11.90409671 -31.96087144 -13.03381044]]
     '''
    잘 짜여진 코드는 아니지만 코드 자체의 복잡함은 크게 없기 때문에 금방 이해할 수 있을 것이다. Mixed-Precision decomposition에서 인풋 매트릭스를 분해 할시에 row 방향이 아닌 column 방향으로 해야하는 것만 확인한다면 크게 복잡한 것은 없다. 해당 조치의 이유는 두 개의 행렬곱 연산이 분할 전의 행렬곱 연산의 차원과 같은 크기를 맞추기 위해서이다.
profile
파이썬과 함께라면 두렵지 않아

0개의 댓글