[경량화 챌린지] 16일차 - affine quantization 구현

ehghkwl·2025년 12월 14일

Lightweight Challenge

목록 보기
16/22

기존 이론만으로 공부했던 양자화 공식 q=round(X/S+Z)q = round(X/S+Z) 을 실제 pytorch를 사용해 quantization하는 코드를 구현해보았다.

핵심 로직은 다음과 같다.
1. Affine mapping: X를 quantization을 위한 좌표계로 변환
2. Rounding: 소수점 반올림
3. Clipping: 범위를 넘어가는 이상치 삭제
4. Casting: 정수형(낮은 비트)으로 변환

구현 코드

import torch

def quantize_tensor(x, scale, zero_point, q_min, q_max, n_bits):
    """
    [미션] Affine Quantization 공식 구현하기
    공식: q = clip(round(x / scale + zero_point), q_min, q_max)
    """
    # 1. 실수 x를 정수 scale로 변환
    q = (x / scale) + zero_point
    
    # 2. 반올림
    q = torch.round(q)
    
    # 3. quantize 범위를 벗어나는 값 자르기 (q_min ~ q_max)
    q = torch.clamp(q, min = q_min, max = q_max)
    
    # 4. 최종 n_bits 타입으로 변환
    # if (4 == n_bits):
    #     q = q.to(dtype = torch.int4)
    # elif (8 == n_bits):
    #     q = q.to(dtype = torch.int8)
    # else:
    #     raise ValueError("n_bits는 4 또는 8 이어야 합니다.")
    # 4,8 비트 분리해서 해보려하였으나, pytorch에는 int4가 존재하지 않아서 오류가 발생한다 ㅇㅁㅇ
    q = q.to(dtype = torch.int8)
        
    return q

def get_quantization_params(x, mode, n_bits):
    """
    x: 입력 텐서
    mode: 'symmetric' 또는 'asymmetric'
    n_bits: 목표 비트 수 (예: 4 또는 8)
    """
    # 1. q_min, q_max 설정
    q_min = -(2**(n_bits-1))
    q_max = (2**(n_bits-1))-1
    
    # 2. x_min, x_max 확인
    x_min = x.min()
    x_max = x.max()
    
    # 대칭모드
    if ("symmetric" == mode):
        x_abs_max = max(abs(x_min), abs(x_max))
        scale = (2*x_abs_max) / (q_max - q_min + 1e-5)
        zero_point = 0
    # 비대칭 모드
    elif ("asymmetric" == mode):
        scale = (x_max - x_min) / (q_max - q_min + 1e-5)
        zero_point = torch.round(q_max - (x_max/scale))
    else:
        raise ValueError("mode는 'symmetric' 또는 'asymmetric'이어야 합니다.")
    
    return scale, zero_point, q_min, q_max 

def quantize_auto(x, mode="symmetric", n_bits=8):
    scale, zero_point, q_min, q_max = get_quantization_params(x, mode, n_bits)
    print(f"[{mode} / {n_bits}bit] scale: {scale:.4f}, zero_point: {zero_point}")
    
    q = quantize_tensor(x, scale, zero_point, q_min, q_max, n_bits)
    return q

# --- 테스트 코드 ---
if __name__ == "__main__":
    data = torch.tensor([-120.0, -10.0, 0.0, 10.0, 120.0, 130.0])

    # 1. Symmetric INT8 테스트
    print("1. Symmetric INT8 결과:", quantize_auto(data, 'symmetric', 8))
    
    # 2. Asymmetric INT4 테스트
    print("\n2. Asymmetric INT4 결과:", quantize_auto(data, 'asymmetric', 4))

구현 도중 겪은 문제는 아래와 같다.

  • 🚨 문제 1: AttributeError: module 'torch' has no attribute 'int4'
    확인해보니, pytorch에서는 4비트 자료형을 지원하지 않았다.

    • 근데, 왜 int4는 지원하지 않는걸까?
      우선, 우리가 일반적으로 사용하는 cpu/gpu에는 4bit 연산 장치가 없다.
      그래서 pytorch나 bitsandbytes가 int4를 지원하는것은 8비트에 4비트 데이터 2개를 넣고, 사용할때에 unpacking 해서 사용하는 형태이다.
      그리고, 1-bit quantization의 경우에는 torch.int1 이런 자료형이 있는게 아니고, 동일하게 8비트에 1비트 데이터 8개를 packing해서 넣고, CUDA kernel 함수를 만들어서, pytorch에 연결해서 사용하는 형태이다.ㅇㅁㅇ 나중에 구현해서 해보면 좋을거같다.
  • 🚨 문제 2: 비트 범위 계산 실수
    8비트의 범위가 -255~256으로 잘못 수식을 만들었다.
    quantizaiton은 (2N11,2N1)(-2^{N-1}-1, 2^{N-1})이 아니라, (2N1,2N11)(-2^{N-1}, 2^{N-1}-1) 로 범위가 정해지는데, 이거는 2의 보수법 때문이다. 생각해보면, 4비트는 총 16가지를 표현할 수 있다. 그래서 -8~8이면 딱 좋겠지만, 우리에겐 0이라는 숫자가 존재한다. 이 0을 표현하기 위해서 1개의 값을 희생해야하는데 이게 8이라는 숫자이다. 왜 8인고 하면, 컴퓨터는 맨앞의 숫자를 부호로 사용하게 되는데, 0은 양수, 1은 음수이고, 음수는 다음과 같다. (01부호 반대로 하고+1)
    1001 -> 0111 = -7
    1010 -> 0110 = -6
    1011 -> 0101 = -5
    1100 -> 0100 = -4
    1101 -> 0011 = -3
    1110 -> 0010 = -2
    1111 -> 0001 = -1
    1000 -> 1000 = -8
    그리고 양수는 다음과 같다.
    0001 = 1
    0010 = 2
    0011 = 3
    0100 = 4
    0101 = 5
    0110 = 6
    0111 = 7
    0000 = 0

  • 🚨 추가 내용 1: clipping
    PTQ를 위한 s,z를 정하는 샘플 데이터 내에 이상값이 존재하는 경우, quantization이 효율적으로 되지 않을 수 있다. clip는 s,z를 만드는 과정에서 말고, 실제 데이터를 대입하는 경우에 유의미한것으로 보인다.

profile
안녕하세요.

0개의 댓글