class QuantLinear(nn.Linear):
def __init__(self, in_features, out_features, bias=True, n_bits=4):
# 1. nn.Linear 초기화
super().__init__(in_features, out_features, bias)
self.n_bits = n_bits
def forward(self, data):
output = 0
"""
1. self.weight를 가져와서 Fake Quantization(Quant -> Dequant)을 수행하세요.
(주의: input은 건드리지 않습니다. Weight-only Quantization 입니다.)
2. 변조된 가중치(w_fake)를 사용하여 선형 연산(F.linear)을 수행하고 반환하세요.
"""
# [Step 1] 가중치 양자화 파라미터 (Scale, Zero Point) 추출 (mode는 'asymmetric')
scale, zero_point, q_min, q_max = get_quantization_params(self.weight, 'asymmetric', self.n_bits)
weight_quant = quantize_tensor(self.weight, scale, zero_point, q_min, q_max, self.n_bits)
# [Step 2] Fake Quantization 수행 (Quantization -> Dequantization)
weight_dequant = dequantize_tensor(weight_quant, scale, zero_point)
# [Step 3] 연산 수행 및 결과 반환
output = F.linear(data, weight_dequant, self.bias)
return output
if __name__ == "__main__":
torch.manual_seed(0)
in_f, out_f = 32, 10
layer = QuantLinear(in_f, out_f, bias=True, n_bits=4)
x = torch.randn(1,in_f)
y = layer(x)
print(f"Layer Type: {type(layer)}")
print(f"Weight Shape: {layer.weight.shape}") # (10, 32)
print(f"Output Shape: {y.shape}") # (1, 10)
# 4. 가중치가 정말 변했는지 확인하기 위해, 강제로 w_fake를 다시 만들어 비교
with torch.no_grad():
# 작성자님이 구현한 로직이 맞다면, 아래 오차가 0이 아니어야 함 (정보 손실 발생)
y_origin = F.linear(x, layer.weight, layer.bias)
diff = torch.abs(y - y_origin).mean()
print(f"Original Linear vs QuantLinear 차이(MSE): {diff:.6f}")
if diff > 0:
print("✅ 축하합니다! Fake Quantization이 정상 적용되었습니다.")
else:
print("❌ 실패. 오차가 0입니다. (원본 가중치를 그대로 쓴 것 같습니다)")
구현 도중 겪은 문제는 아래와 같다.
열심히 하는 모습이 멋있어요. 팬이에요 ㅎㅎ