import torch.nn as nn import torch def quantize_fp8(tensor: torch.Tensor, scale: torch.Tensor): dtype = tensor.dtype clamp_min, clamp_max = torch.tensor(-240., dtype=dtype), torch.tensor(240., dtype=dtype) quant_tensor = torch.clamp((tensor/scale), clamp_min, clamp_max).to(torch.float8_e4m3fnuz).to(dtype) return quant_tensor def dequantize_fp8(tensor: torch.Tensor, scale: torch.Tensor): return tensor * scale class QuantLinear(nn.Module): def __init__(self, in_ch, out_ch, quant_param): super().__init__() mul_factor = torch.tensor(quant_param['smoothquant_mul']).view(quant_param['smoothquant_mul_shape']) self.register_buffer('mul_factor', mul_factor) self.linear = nn.Linear(in_ch, out_ch) weight_scale = torch.tensor(quant_param['weight_scale']).view(quant_param['weight_scale_shape']) # weight_zp = torch.tensor(quant_param['weight_zp']).view(quant_param['weight_zp_shape']) # assert quant_param['weight_zp_dtype'] == 'torch.float8_e4m3fnuz', f"Weight Zero-Point dtype should be 'torch.float8_e4m3fnuz', found: {quant_param['weight_zp_dype']}" input_scale = torch.tensor(quant_param['input_scale']).view(quant_param['input_scale_shape']) input_zp = torch.tensor(quant_param['input_zp']).view(quant_param['input_zp_shape']) assert quant_param['input_zp_dtype'] == 'torch.float8_e4m3fnuz', f"Input Zero-Point dtype should be 'torch.float8_e4m3fnuz', found: {quant_param['input_zp_dype']}" self.register_buffer('weight_scale', weight_scale) # self.register_buffer('weight_zp', weight_zp) self.register_buffer('input_scale', input_scale) self.register_buffer('input_zp', input_zp) # I.e., "fake quantization" def qdq_forward(self, x): print(self.mul_factor.shape) scaled_x = x * self.mul_factor quant_weight = quantize_fp8(self.linear.weight, self.weight_scale) quant_input = quantize_fp8(scaled_x, self.input_scale) dequantized_weight = dequantize_fp8(quant_weight, self.weight_scale) dequantized_input = dequantize_fp8(quant_input, self.input_scale) out = torch.nn.functional.linear(dequantized_input, dequantized_weight, self.linear.bias) return out # Accelerated version def qop_forward(self, x): quant_weight = quantize_fp8(self.linear.weight, self.weight_scale).to(torch.float8_e4m3fnuz) fused_input_scale = self.input_scale / self.mul_factor # Fuse SmoothQuant and input scales, can be computed offline quant_input = quantize_fp8(x, fused_input_scale).to(torch.float8_e4m3fnuz) quant_output = torch.nn.functional.linear(quant_input.to(torch.float32), quant_weight.to(torch.float32), None).to(torch.float32) # Convert inputs to FP32 to avoid F.linear quantizing the output to int8 output = dequantize_fp8(quant_output, (self.weight_scale * self.input_scale).view([1]*(quant_output.ndim-1) + [(self.weight_scale * self.input_scale).nelement()])) output += self.linear.bias return output def forward(self, x, qop=False): if qop: return self.qop_forward(x) else: return self.qdq_forward(x) class QuantConv2d(nn.Module): def __init__(self, in_ch, out_ch, kernel_size, quant_param): super().__init__() mul_factor = torch.tensor(quant_param['smoothquant_mul']).view(quant_param['smoothquant_mul_shape']) self.register_buffer('mul_factor', mul_factor) self.conv2d = nn.Conv2d(in_ch, out_ch, kernel_size) weight_scale = torch.tensor(quant_param['weight_scale']).view(quant_param['weight_scale_shape']) input_scale = torch.tensor(quant_param['input_scale']).view(quant_param['input_scale_shape']) input_zp = torch.tensor(quant_param['input_zp']).view(quant_param['input_zp_shape']) assert quant_param['input_zp_dtype'] == 'torch.float8_e4m3fnuz', f"Input Zero-Point dtype should be 'torch.float8_e4m3fnuz', found: {quant_param['input_zp_dype']}" self.register_buffer('weight_scale', weight_scale) self.register_buffer('input_scale', input_scale) self.register_buffer('input_zp', input_zp) # I.e., "fake quantization" def qdq_forward(self, x): scaled_x = x * self.mul_factor quant_weight = quantize_fp8(self.conv2d.weight, self.weight_scale) quant_input = quantize_fp8(scaled_x, self.input_scale) dequantized_weight = dequantize_fp8(quant_weight, self.weight_scale) dequantized_input = dequantize_fp8(quant_input, self.input_scale) out = torch.nn.functional.conv2d(dequantized_input, dequantized_weight, self.conv2d.bias) return out # Accelerated version def qop_forward(self, x): quant_weight = quantize_fp8(self.conv2d.weight, self.weight_scale).to(torch.float8_e4m3fnuz) fused_input_scale = self.input_scale / self.mul_factor # Fuse SmoothQuant and input scales, can be computed offline quant_input = quantize_fp8(x, fused_input_scale).to(torch.float8_e4m3fnuz) quant_output = torch.nn.functional.conv2d(quant_input.to(torch.float32), quant_weight.to(torch.float32), None).to(torch.float32) # Convert inputs to FP32 to avoid F.conv2d quantizing the output to int8 output = dequantize_fp8(quant_output, (self.weight_scale * self.input_scale).view([1, (self.weight_scale * self.input_scale).nelement()] + [1]*(quant_output.ndim-2))) output += self.conv2d.bias.view([1, self.conv2d.bias.nelement()] + [1]*(quant_output.ndim-2)) return output def forward(self, x, qop=False): if qop: return self.qop_forward(x) else: return self.qdq_forward(x) torch.manual_seed(0) batch_size = 1 seq_len = 11 hidden_size = 21 output_size = 36 shape = 5 query = 2.*torch.rand((batch_size, seq_len, hidden_size)) - 1. conv_input = 2.*torch.rand((batch_size, hidden_size, shape, shape)) - 1. quant_params = { "quant_linear": { "smoothquant_mul": torch.randn(hidden_size).abs(), "smoothquant_mul_shape": [1, 1, hidden_size], "input_scale": torch.max(torch.abs(query)) / 240., "input_scale_shape": [], "input_zp": 0.0, "input_zp_shape": [], "input_zp_dtype": "torch.float8_e4m3fnuz", "weight_scale":torch.randn(output_size).abs(), "weight_scale_shape": [output_size, 1] }, "quant_conv": { "smoothquant_mul": torch.randn(hidden_size).abs(), "smoothquant_mul_shape": [1, hidden_size, 1, 1], "input_scale": torch.max(torch.abs(query)) / 240., "input_scale_shape": [], "input_zp": 0.0, "input_zp_shape": [], "input_zp_dtype": "torch.float8_e4m3fnuz", "weight_scale":torch.randn(output_size).abs(), "weight_scale_shape": [output_size, 1, 1, 1] } } qlinear = QuantLinear(hidden_size, output_size, quant_params['quant_linear']) o = qlinear(query) q_o = qlinear(query, qop=True) assert torch.allclose(o, q_o) qconv = QuantConv2d(hidden_size, output_size, shape, quant_params['quant_conv']) o = qconv(conv_input) q_o = qconv(conv_input, qop=True) assert torch.allclose(o, q_o, atol=1e-6)