[경량화 챌린지] 20일차 - quantization layer 구현

ehghkwl·2025년 12월 22일

Lightweight Challenge

목록 보기
20/22

nn.linear 를 통해 quantization layer 구현

class QuantLinear(nn.Linear):
    def __init__(self, in_features, out_features, bias=True, n_bits=4):
        # 1. nn.Linear 초기화
        super().__init__(in_features, out_features, bias)
        self.n_bits = n_bits
    
    def forward(self, data):
        output = 0
        """
        1. self.weight를 가져와서 Fake Quantization(Quant -> Dequant)을 수행하세요.
           (주의: input은 건드리지 않습니다. Weight-only Quantization 입니다.)
        2. 변조된 가중치(w_fake)를 사용하여 선형 연산(F.linear)을 수행하고 반환하세요.
        """
        # [Step 1] 가중치 양자화 파라미터 (Scale, Zero Point) 추출 (mode는 'asymmetric')
        scale, zero_point, q_min, q_max = get_quantization_params(self.weight, 'asymmetric', self.n_bits)
        weight_quant = quantize_tensor(self.weight, scale, zero_point, q_min, q_max, self.n_bits)
        
        # [Step 2] Fake Quantization 수행 (Quantization -> Dequantization)
        weight_dequant = dequantize_tensor(weight_quant, scale, zero_point)
        
        # [Step 3] 연산 수행 및 결과 반환
        output = F.linear(data, weight_dequant, self.bias)
        
        return output

if __name__ == "__main__":
    torch.manual_seed(0)
    
    in_f, out_f = 32, 10
    layer = QuantLinear(in_f, out_f, bias=True, n_bits=4)
    
    x = torch.randn(1,in_f)
    
    y = layer(x)
    
    print(f"Layer Type: {type(layer)}")
    print(f"Weight Shape: {layer.weight.shape}") # (10, 32)
    print(f"Output Shape: {y.shape}")            # (1, 10)
    
    # 4. 가중치가 정말 변했는지 확인하기 위해, 강제로 w_fake를 다시 만들어 비교
    with torch.no_grad():
        # 작성자님이 구현한 로직이 맞다면, 아래 오차가 0이 아니어야 함 (정보 손실 발생)
        y_origin = F.linear(x, layer.weight, layer.bias)
        diff = torch.abs(y - y_origin).mean()
        
        print(f"Original Linear vs QuantLinear 차이(MSE): {diff:.6f}")
        if diff > 0:
            print("✅ 축하합니다! Fake Quantization이 정상 적용되었습니다.")
        else:
            print("❌ 실패. 오차가 0입니다. (원본 가중치를 그대로 쓴 것 같습니다)")

구현 도중 겪은 문제는 아래와 같다.

  • 🚨 문제 1: quantization 대상 실수
    이전에 이론적으로 quantization을 배울때는 그냥 data가 들어오면, quantization하면 되겠지 생각하고, forward에서도 그냥 들어온 data에 대해서, parameter를 구하고, 그 값으로 weight를 quantization을 수행했다. 그결과, quantization이 되긴 했지만, 오차가 엄청엄청 컸다.
    Original Linear vs QuantLinear 차이(MSE): 0.730598
    그래서 파라미터 구할때부터 self.weight를 사용해서 하니까, 정상적으로 구해졌다!
    Original Linear vs QuantLinear 차이(MSE): 0.037802
profile
안녕하세요.

1개의 댓글

comment-user-thumbnail
2025년 12월 23일

열심히 하는 모습이 멋있어요. 팬이에요 ㅎㅎ

답글 달기