[경량화 챌린지] 18일차 - DeQuantization 구현

ehghkwl·2025년 12월 16일

Lightweight Challenge

목록 보기
18/22

어제 dequantization을 개념으로만 공부했는데, 오늘은 실제 dequantization 코드를 구현해봤다.
dequantization하고 난 이후에 원본값과 얼마나 달라지는지 확인한다.

구현 코드

quantize 함수는 16일차에 코드 있음.

def dequantize_tensor(q, scale, zero_point):
    """
    [미션 1] 역양자화(De-quantization) 구현
    공식: X_recon = scale * (q - zero_point)
    """
    # 1. 정수형 q를 실수형(float32)으로 변환 (계산을 위해 필수!)
    q_float = q.float()
    
    # 2. 공식 적용
    x_recon = (q_float - zero_point) * scale
    
    return x_recon

def get_quantization_error(x_original, x_recon):
    """
    [미션 2] 양자화 오차(MSE) 측정
    MSE = 평균((원본 - 복구본)^2)
    """
    # 1. 차이 계산
    diff = x_original - x_recon
    
    # 2. 제곱 후 평균 (MSE)
    mse = torch.mean(diff**2)
    
    return mse

# --- 통합 테스트 ---
if __name__ == "__main__":
    # 데이터 생성
    data = torch.tensor([-120.0, -10.0, 0.0, 10.0, 120.0, 130.0])
    
    # 1. 양자화 수행 (어제 만든 함수 활용)
    # 테스트를 위해 파라미터를 먼저 구하고 함수를 호출합니다.
    # (비대칭 4비트로 테스트 해봅시다. 오차가 잘 보이거든요!)
    scale, zero_point, q_min, q_max = get_quantization_params(data, 'asymmetric', 4)
    q = quantize_tensor(data, scale, zero_point, q_min, q_max, 4)
        
    print("=== [Process: Original -> Quant -> Dequant] ===")
    print(f"Original : {data}")
    print(f"Quantized: {q}")
    
    # 2. 역양자화 수행 (복구)
    recon_data = dequantize_tensor(q, scale, zero_point)
    print(f"Reconstructed: {recon_data}")
    
    # 3. 오차 측정
    error = get_quantization_error(data, recon_data)
    print(f"MSE Error : {error.item():.6f}")

결과

=== [Process: Original -> Quant -> Dequant] ===
Original : tensor([-120.,  -10.,    0.,   10.,  120.,  130.])
Quantized: tensor([-8, -2, -1,  0,  6,  7], dtype=torch.int8)
Reconstructed: tensor([-116.6666,  -16.6667,    0.0000,   16.6667,  116.6666,  133.3333])
MSE Error : 20.370392
  • 혹시, quantization error도 경량화 논문에서는? 성능 지표로 사용하나?
    최종적으로는 모델 성능으로 하지만, error 를 사용해서, 모델 성능이 떨어진 이유를 찾아낸다.
profile
안녕하세요.

0개의 댓글