어제 dequantization을 개념으로만 공부했는데, 오늘은 실제 dequantization 코드를 구현해봤다.
dequantization하고 난 이후에 원본값과 얼마나 달라지는지 확인한다.
quantize 함수는 16일차에 코드 있음.
def dequantize_tensor(q, scale, zero_point):
"""
[미션 1] 역양자화(De-quantization) 구현
공식: X_recon = scale * (q - zero_point)
"""
# 1. 정수형 q를 실수형(float32)으로 변환 (계산을 위해 필수!)
q_float = q.float()
# 2. 공식 적용
x_recon = (q_float - zero_point) * scale
return x_recon
def get_quantization_error(x_original, x_recon):
"""
[미션 2] 양자화 오차(MSE) 측정
MSE = 평균((원본 - 복구본)^2)
"""
# 1. 차이 계산
diff = x_original - x_recon
# 2. 제곱 후 평균 (MSE)
mse = torch.mean(diff**2)
return mse
# --- 통합 테스트 ---
if __name__ == "__main__":
# 데이터 생성
data = torch.tensor([-120.0, -10.0, 0.0, 10.0, 120.0, 130.0])
# 1. 양자화 수행 (어제 만든 함수 활용)
# 테스트를 위해 파라미터를 먼저 구하고 함수를 호출합니다.
# (비대칭 4비트로 테스트 해봅시다. 오차가 잘 보이거든요!)
scale, zero_point, q_min, q_max = get_quantization_params(data, 'asymmetric', 4)
q = quantize_tensor(data, scale, zero_point, q_min, q_max, 4)
print("=== [Process: Original -> Quant -> Dequant] ===")
print(f"Original : {data}")
print(f"Quantized: {q}")
# 2. 역양자화 수행 (복구)
recon_data = dequantize_tensor(q, scale, zero_point)
print(f"Reconstructed: {recon_data}")
# 3. 오차 측정
error = get_quantization_error(data, recon_data)
print(f"MSE Error : {error.item():.6f}")
=== [Process: Original -> Quant -> Dequant] ===
Original : tensor([-120., -10., 0., 10., 120., 130.])
Quantized: tensor([-8, -2, -1, 0, 6, 7], dtype=torch.int8)
Reconstructed: tensor([-116.6666, -16.6667, 0.0000, 16.6667, 116.6666, 133.3333])
MSE Error : 20.370392