# sdxl-quant-fp8 / math_model.py
import torch
import torch.nn as nn

def quantize_fp8(tensor: torch.Tensor, scale: torch.Tensor):
    # Scale, clamp to the FP8 e4m3fnuz range [-240, 240], and round onto the
    # FP8 grid; the result is returned in the input dtype.
    dtype = tensor.dtype
    clamp_min, clamp_max = torch.tensor(-240., dtype=dtype), torch.tensor(240., dtype=dtype)
    quant_tensor = torch.clamp(tensor / scale, clamp_min, clamp_max).to(torch.float8_e4m3fnuz).to(dtype)
    return quant_tensor

def dequantize_fp8(tensor: torch.Tensor, scale: torch.Tensor):
    # Map quantized values back to the original range.
    return tensor * scale
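
# Illustrative round-trip (hypothetical values, not part of the model): a
# quantize/dequantize pair reproduces a tensor up to FP8 rounding error,
# roughly half an e4m3 step (~6% relative) plus one quantization step near zero.
_t = torch.linspace(-1., 1., 9)
_s = torch.tensor(1. / 240.)  # maps [-1, 1] onto the full [-240, 240] FP8 range
assert torch.allclose(dequantize_fp8(quantize_fp8(_t, _s), _s), _t, rtol=0.07, atol=float(_s))
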
class QuantLinear(nn.Module):
def __init__(self, in_ch, out_ch, quant_param):
super().__init__()
mul_factor = torch.tensor(quant_param['smoothquant_mul']).view(quant_param['smoothquant_mul_shape'])
self.register_buffer('mul_factor', mul_factor)
self.linear = nn.Linear(in_ch, out_ch)
weight_scale = torch.tensor(quant_param['weight_scale']).view(quant_param['weight_scale_shape'])
# weight_zp = torch.tensor(quant_param['weight_zp']).view(quant_param['weight_zp_shape'])
        # assert quant_param['weight_zp_dtype'] == 'torch.float8_e4m3fnuz', f"Weight Zero-Point dtype should be 'torch.float8_e4m3fnuz', found: {quant_param['weight_zp_dtype']}"
input_scale = torch.tensor(quant_param['input_scale']).view(quant_param['input_scale_shape'])
input_zp = torch.tensor(quant_param['input_zp']).view(quant_param['input_zp_shape'])
        assert quant_param['input_zp_dtype'] == 'torch.float8_e4m3fnuz', f"Input Zero-Point dtype should be 'torch.float8_e4m3fnuz', found: {quant_param['input_zp_dtype']}"
self.register_buffer('weight_scale', weight_scale)
# self.register_buffer('weight_zp', weight_zp)
self.register_buffer('input_scale', input_scale)
self.register_buffer('input_zp', input_zp)
# I.e., "fake quantization"
    def qdq_forward(self, x):
        scaled_x = x * self.mul_factor
quant_weight = quantize_fp8(self.linear.weight, self.weight_scale)
quant_input = quantize_fp8(scaled_x, self.input_scale)
dequantized_weight = dequantize_fp8(quant_weight, self.weight_scale)
dequantized_input = dequantize_fp8(quant_input, self.input_scale)
out = torch.nn.functional.linear(dequantized_input, dequantized_weight, self.linear.bias)
return out
# Accelerated version
def qop_forward(self, x):
quant_weight = quantize_fp8(self.linear.weight, self.weight_scale).to(torch.float8_e4m3fnuz)
        fused_input_scale = self.input_scale / self.mul_factor  # Fuse the SmoothQuant and input scales; this can be computed offline (identity checked after this class)
quant_input = quantize_fp8(x, fused_input_scale).to(torch.float8_e4m3fnuz)
        quant_output = torch.nn.functional.linear(quant_input.to(torch.float32), quant_weight.to(torch.float32), None)  # F.linear has no FP8 matmul path here, so compute in FP32; both operands are already rounded to the FP8 grid
        dequant_scale = self.weight_scale * self.input_scale
        output = dequantize_fp8(quant_output, dequant_scale.view([1] * (quant_output.ndim - 1) + [dequant_scale.nelement()]))
        output += self.linear.bias
return output
def forward(self, x, qop=False):
if qop:
return self.qop_forward(x)
else:
return self.qdq_forward(x)
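
# Illustrative check of the scale-fusion identity used by qop_forward:
# quantizing x * m with scale s equals quantizing x with the fused scale
# s / m, because (x * m) / s == x / (s / m). The tensors below are
# hypothetical; the loose tolerance allows a one-step FP8 disagreement if
# FP32 rounding of the two division orders straddles a rounding boundary.
_x = torch.randn(4, 8)
_m = torch.rand(8) + 0.5  # stand-in for a per-channel SmoothQuant multiplier
_s = torch.tensor(0.05)   # stand-in for an input scale
assert torch.allclose(quantize_fp8(_x * _m, _s), quantize_fp8(_x, _s / _m), rtol=0.13, atol=float(_s))
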
class QuantConv2d(nn.Module):
def __init__(self, in_ch, out_ch, kernel_size, quant_param):
super().__init__()
mul_factor = torch.tensor(quant_param['smoothquant_mul']).view(quant_param['smoothquant_mul_shape'])
self.register_buffer('mul_factor', mul_factor)
self.conv2d = nn.Conv2d(in_ch, out_ch, kernel_size)
weight_scale = torch.tensor(quant_param['weight_scale']).view(quant_param['weight_scale_shape'])
input_scale = torch.tensor(quant_param['input_scale']).view(quant_param['input_scale_shape'])
input_zp = torch.tensor(quant_param['input_zp']).view(quant_param['input_zp_shape'])
        assert quant_param['input_zp_dtype'] == 'torch.float8_e4m3fnuz', f"Input Zero-Point dtype should be 'torch.float8_e4m3fnuz', found: {quant_param['input_zp_dtype']}"
self.register_buffer('weight_scale', weight_scale)
self.register_buffer('input_scale', input_scale)
self.register_buffer('input_zp', input_zp)
# I.e., "fake quantization"
def qdq_forward(self, x):
scaled_x = x * self.mul_factor
quant_weight = quantize_fp8(self.conv2d.weight, self.weight_scale)
quant_input = quantize_fp8(scaled_x, self.input_scale)
dequantized_weight = dequantize_fp8(quant_weight, self.weight_scale)
dequantized_input = dequantize_fp8(quant_input, self.input_scale)
out = torch.nn.functional.conv2d(dequantized_input, dequantized_weight, self.conv2d.bias)
return out
# Accelerated version
def qop_forward(self, x):
quant_weight = quantize_fp8(self.conv2d.weight, self.weight_scale).to(torch.float8_e4m3fnuz)
        fused_input_scale = self.input_scale / self.mul_factor  # Fuse the SmoothQuant and input scales; this can be computed offline (same identity as in QuantLinear)
quant_input = quantize_fp8(x, fused_input_scale).to(torch.float8_e4m3fnuz)
        quant_output = torch.nn.functional.conv2d(quant_input.to(torch.float32), quant_weight.to(torch.float32), None)  # F.conv2d has no FP8 kernels here, so compute in FP32; both operands are already rounded to the FP8 grid
        dequant_scale = self.weight_scale * self.input_scale
        output = dequantize_fp8(quant_output, dequant_scale.view([1, dequant_scale.nelement()] + [1] * (quant_output.ndim - 2)))  # per-channel scale broadcast; see the shape note after this class
        output += self.conv2d.bias.view([1, self.conv2d.bias.nelement()] + [1] * (quant_output.ndim - 2))
return output
def forward(self, x, qop=False):
if qop:
return self.qop_forward(x)
else:
return self.qdq_forward(x)
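
# Shape note (illustrative): for a 4-D conv output, the per-channel dequant
# scale and bias are reshaped to [1, C, 1, 1] so they broadcast over the
# batch and spatial dimensions. The channel count below is hypothetical.
_c = 3
assert list(torch.rand(_c).view([1, _c] + [1] * 2).shape) == [1, _c, 1, 1]
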
# Smoke test: the fake-quantized (QDQ) and accelerated (QOp) paths must agree.
torch.manual_seed(0)
batch_size = 1
seq_len = 11
hidden_size = 21
output_size = 36
shape = 5
query = 2.*torch.rand((batch_size, seq_len, hidden_size)) - 1.
conv_input = 2.*torch.rand((batch_size, hidden_size, shape, shape)) - 1.
quant_params = {
"quant_linear": {
"smoothquant_mul": torch.randn(hidden_size).abs(),
"smoothquant_mul_shape": [1, 1, hidden_size],
"input_scale": torch.max(torch.abs(query)) / 240.,
"input_scale_shape": [],
"input_zp": 0.0,
"input_zp_shape": [],
"input_zp_dtype": "torch.float8_e4m3fnuz",
"weight_scale":torch.randn(output_size).abs(),
"weight_scale_shape": [output_size, 1]
},
"quant_conv": {
"smoothquant_mul": torch.randn(hidden_size).abs(),
"smoothquant_mul_shape": [1, hidden_size, 1, 1],
"input_scale": torch.max(torch.abs(query)) / 240.,
"input_scale_shape": [],
"input_zp": 0.0,
"input_zp_shape": [],
"input_zp_dtype": "torch.float8_e4m3fnuz",
"weight_scale":torch.randn(output_size).abs(),
"weight_scale_shape": [output_size, 1, 1, 1]
}
}
qlinear = QuantLinear(hidden_size, output_size, quant_params['quant_linear'])
o = qlinear(query)
q_o = qlinear(query, qop=True)
assert torch.allclose(o, q_o)
qconv = QuantConv2d(hidden_size, output_size, shape, quant_params['quant_conv'])
o = qconv(conv_input)
q_o = qconv(conv_input, qop=True)
assert torch.allclose(o, q_o, atol=1e-6)
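
# Illustrative summary: any residual QDQ-vs-QOp mismatch comes from FP32
# accumulation order only, since both paths apply identical FP8 rounding.
print(f"conv2d qdq vs qop max abs diff: {(o - q_o).abs().max().item():.3e}")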