기존 이론만으로 공부했던 양자화 공식 을 실제 pytorch를 사용해 quantization하는 코드를 구현해보았다.
핵심 로직은 다음과 같다.
1. Affine mapping: X를 quantization을 위한 좌표계로 변환
2. Rounding: 소수점 반올림
3. Clipping: 범위를 넘어가는 이상치 삭제
4. Casting: 정수형(낮은 비트)으로 변환
import torch
def quantize_tensor(x, scale, zero_point, q_min, q_max, n_bits):
"""
[미션] Affine Quantization 공식 구현하기
공식: q = clip(round(x / scale + zero_point), q_min, q_max)
"""
# 1. 실수 x를 정수 scale로 변환
q = (x / scale) + zero_point
# 2. 반올림
q = torch.round(q)
# 3. quantize 범위를 벗어나는 값 자르기 (q_min ~ q_max)
q = torch.clamp(q, min = q_min, max = q_max)
# 4. 최종 n_bits 타입으로 변환
# if (4 == n_bits):
# q = q.to(dtype = torch.int4)
# elif (8 == n_bits):
# q = q.to(dtype = torch.int8)
# else:
# raise ValueError("n_bits는 4 또는 8 이어야 합니다.")
# 4,8 비트 분리해서 해보려하였으나, pytorch에는 int4가 존재하지 않아서 오류가 발생한다 ㅇㅁㅇ
q = q.to(dtype = torch.int8)
return q
def get_quantization_params(x, mode, n_bits):
"""
x: 입력 텐서
mode: 'symmetric' 또는 'asymmetric'
n_bits: 목표 비트 수 (예: 4 또는 8)
"""
# 1. q_min, q_max 설정
q_min = -(2**(n_bits-1))
q_max = (2**(n_bits-1))-1
# 2. x_min, x_max 확인
x_min = x.min()
x_max = x.max()
# 대칭모드
if ("symmetric" == mode):
x_abs_max = max(abs(x_min), abs(x_max))
scale = (2*x_abs_max) / (q_max - q_min + 1e-5)
zero_point = 0
# 비대칭 모드
elif ("asymmetric" == mode):
scale = (x_max - x_min) / (q_max - q_min + 1e-5)
zero_point = torch.round(q_max - (x_max/scale))
else:
raise ValueError("mode는 'symmetric' 또는 'asymmetric'이어야 합니다.")
return scale, zero_point, q_min, q_max
def quantize_auto(x, mode="symmetric", n_bits=8):
scale, zero_point, q_min, q_max = get_quantization_params(x, mode, n_bits)
print(f"[{mode} / {n_bits}bit] scale: {scale:.4f}, zero_point: {zero_point}")
q = quantize_tensor(x, scale, zero_point, q_min, q_max, n_bits)
return q
# --- 테스트 코드 ---
if __name__ == "__main__":
data = torch.tensor([-120.0, -10.0, 0.0, 10.0, 120.0, 130.0])
# 1. Symmetric INT8 테스트
print("1. Symmetric INT8 결과:", quantize_auto(data, 'symmetric', 8))
# 2. Asymmetric INT4 테스트
print("\n2. Asymmetric INT4 결과:", quantize_auto(data, 'asymmetric', 4))
구현 도중 겪은 문제는 아래와 같다.
🚨 문제 1: AttributeError: module 'torch' has no attribute 'int4'
확인해보니, pytorch에서는 4비트 자료형을 지원하지 않았다.
🚨 문제 2: 비트 범위 계산 실수
8비트의 범위가 -255~256으로 잘못 수식을 만들었다.
quantizaiton은 이 아니라, 로 범위가 정해지는데, 이거는 2의 보수법 때문이다. 생각해보면, 4비트는 총 16가지를 표현할 수 있다. 그래서 -8~8이면 딱 좋겠지만, 우리에겐 0이라는 숫자가 존재한다. 이 0을 표현하기 위해서 1개의 값을 희생해야하는데 이게 8이라는 숫자이다. 왜 8인고 하면, 컴퓨터는 맨앞의 숫자를 부호로 사용하게 되는데, 0은 양수, 1은 음수이고, 음수는 다음과 같다. (01부호 반대로 하고+1)
1001 -> 0111 = -7
1010 -> 0110 = -6
1011 -> 0101 = -5
1100 -> 0100 = -4
1101 -> 0011 = -3
1110 -> 0010 = -2
1111 -> 0001 = -1
1000 -> 1000 = -8
그리고 양수는 다음과 같다.
0001 = 1
0010 = 2
0011 = 3
0100 = 4
0101 = 5
0110 = 6
0111 = 7
0000 = 0
🚨 추가 내용 1: clipping
PTQ를 위한 s,z를 정하는 샘플 데이터 내에 이상값이 존재하는 경우, quantization이 효율적으로 되지 않을 수 있다. clip는 s,z를 만드는 과정에서 말고, 실제 데이터를 대입하는 경우에 유의미한것으로 보인다.