|
import torch.nn as nn |
|
import torch |
|
|
|
def quantize_fp8(tensor: torch.Tensor, scale: torch.Tensor): |
|
dtype = tensor.dtype |
|
clamp_min, clamp_max = torch.tensor(-240., dtype=dtype), torch.tensor(240., dtype=dtype) |
|
quant_tensor = torch.clamp((tensor/scale), clamp_min, clamp_max).to(torch.float8_e4m3fnuz).to(dtype) |
|
return quant_tensor |
|
|
|
def dequantize_fp8(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Map fake-quantized values back to the real domain.

    Undoes the division by ``scale`` performed in ``quantize_fp8`` by
    multiplying element-wise (with broadcasting) by ``scale``.
    """
    rescaled = tensor * scale
    return rescaled
|
|
|
|
|
class QuantLinear(nn.Module):
    """``nn.Linear`` wrapper with SmoothQuant-style input scaling and FP8 (e4m3fnuz) quantization.

    Two forward paths are provided:

    * ``qdq_forward`` - reference quantize/dequantize ("fake quant") path.
    * ``qop_forward`` - quantized-operator path: the matmul runs on the FP8
      values and the combined weight*input scale is applied to the output.

    ``quant_param`` is a dict with, for each of ``smoothquant_mul``,
    ``weight_scale``, ``input_scale`` and ``input_zp``, a value and a matching
    ``<name>_shape`` entry used to ``view`` it. It also holds
    ``input_zp_dtype``, which must equal ``'torch.float8_e4m3fnuz'``.
    """

    def __init__(self, in_ch, out_ch, quant_param):
        super().__init__()

        mul_factor = self._as_buffer(quant_param['smoothquant_mul'], quant_param['smoothquant_mul_shape'])
        self.register_buffer('mul_factor', mul_factor)

        self.linear = nn.Linear(in_ch, out_ch)

        weight_scale = self._as_buffer(quant_param['weight_scale'], quant_param['weight_scale_shape'])
        input_scale = self._as_buffer(quant_param['input_scale'], quant_param['input_scale_shape'])
        # Stored for completeness / state_dict round-trips; neither forward path reads it.
        input_zp = self._as_buffer(quant_param['input_zp'], quant_param['input_zp_shape'])

        # The message previously indexed the misspelled key 'input_zp_dype', so a
        # bad dtype surfaced as a KeyError instead of this assertion.
        assert quant_param['input_zp_dtype'] == 'torch.float8_e4m3fnuz', (
            "Input Zero-Point dtype should be 'torch.float8_e4m3fnuz', "
            f"found: {quant_param['input_zp_dtype']}"
        )

        self.register_buffer('weight_scale', weight_scale)
        self.register_buffer('input_scale', input_scale)
        self.register_buffer('input_zp', input_zp)

    @staticmethod
    def _as_buffer(value, shape):
        """Copy ``value`` (tensor, number, list or array) into a fresh detached tensor viewed as ``shape``.

        Behaves like ``torch.tensor(value)``, but without the copy-construct
        UserWarning torch emits when ``value`` is already a tensor.
        """
        fresh = torch.as_tensor(value).detach().clone(memory_format=torch.contiguous_format)
        return fresh.view(shape)

    def qdq_forward(self, x):
        """Reference path: fake-quantize the smoothed input and the weight, then run a float linear."""
        scaled_x = x * self.mul_factor
        quant_weight = quantize_fp8(self.linear.weight, self.weight_scale)
        quant_input = quantize_fp8(scaled_x, self.input_scale)
        dequantized_weight = dequantize_fp8(quant_weight, self.weight_scale)
        dequantized_input = dequantize_fp8(quant_input, self.input_scale)
        return torch.nn.functional.linear(dequantized_input, dequantized_weight, self.linear.bias)

    def qop_forward(self, x):
        """Quantized-operator path: matmul on FP8 values, then rescale the output and add bias."""
        quant_weight = quantize_fp8(self.linear.weight, self.weight_scale).to(torch.float8_e4m3fnuz)

        # Fold the SmoothQuant multiplier into the input scale:
        # x / (s / m) == (x * m) / s (up to float rounding), so x * m is never formed.
        fused_input_scale = self.input_scale / self.mul_factor
        quant_input = quantize_fp8(x, fused_input_scale).to(torch.float8_e4m3fnuz)

        # F.linear has no FP8 kernel here; upcast the FP8 values to float32.
        quant_output = torch.nn.functional.linear(
            quant_input.to(torch.float32), quant_weight.to(torch.float32), None
        )

        # Combined scale laid out along the last (output-feature) dimension.
        # TODO(review): assumes a per-tensor input_scale; a per-channel one
        # would not reshape to the output-feature axis correctly.
        output_scale = self.weight_scale * self.input_scale
        output_scale = output_scale.reshape([1] * (quant_output.ndim - 1) + [output_scale.nelement()])
        output = dequantize_fp8(quant_output, output_scale)

        if self.linear.bias is not None:
            output = output + self.linear.bias
        return output

    def forward(self, x, qop=False):
        """Dispatch to ``qop_forward`` when ``qop`` is true, else ``qdq_forward``."""
        if qop:
            return self.qop_forward(x)
        return self.qdq_forward(x)
|
|
|
class QuantConv2d(nn.Module):
    """``nn.Conv2d`` wrapper with SmoothQuant-style input scaling and FP8 (e4m3fnuz) quantization.

    Two forward paths are provided:

    * ``qdq_forward`` - reference quantize/dequantize ("fake quant") path.
    * ``qop_forward`` - quantized-operator path: the convolution runs on the
      FP8 values and the combined weight*input scale is applied per output
      channel.

    ``quant_param`` is a dict with, for each of ``smoothquant_mul``,
    ``weight_scale``, ``input_scale`` and ``input_zp``, a value and a matching
    ``<name>_shape`` entry used to ``view`` it. It also holds
    ``input_zp_dtype``, which must equal ``'torch.float8_e4m3fnuz'``.
    """

    def __init__(self, in_ch, out_ch, kernel_size, quant_param):
        super().__init__()

        mul_factor = self._as_buffer(quant_param['smoothquant_mul'], quant_param['smoothquant_mul_shape'])
        self.register_buffer('mul_factor', mul_factor)

        self.conv2d = nn.Conv2d(in_ch, out_ch, kernel_size)

        weight_scale = self._as_buffer(quant_param['weight_scale'], quant_param['weight_scale_shape'])
        input_scale = self._as_buffer(quant_param['input_scale'], quant_param['input_scale_shape'])
        # Stored for completeness / state_dict round-trips; neither forward path reads it.
        input_zp = self._as_buffer(quant_param['input_zp'], quant_param['input_zp_shape'])

        # The message previously indexed the misspelled key 'input_zp_dype', so a
        # bad dtype surfaced as a KeyError instead of this assertion.
        assert quant_param['input_zp_dtype'] == 'torch.float8_e4m3fnuz', (
            "Input Zero-Point dtype should be 'torch.float8_e4m3fnuz', "
            f"found: {quant_param['input_zp_dtype']}"
        )

        self.register_buffer('weight_scale', weight_scale)
        self.register_buffer('input_scale', input_scale)
        self.register_buffer('input_zp', input_zp)

    @staticmethod
    def _as_buffer(value, shape):
        """Copy ``value`` (tensor, number, list or array) into a fresh detached tensor viewed as ``shape``.

        Behaves like ``torch.tensor(value)``, but without the copy-construct
        UserWarning torch emits when ``value`` is already a tensor.
        """
        fresh = torch.as_tensor(value).detach().clone(memory_format=torch.contiguous_format)
        return fresh.view(shape)

    def _conv(self, x, weight, bias):
        """Run ``F.conv2d`` with the wrapped conv's stride/padding/dilation/groups.

        NOTE(review): ``padding_mode`` other than 'zeros' is not replicated here.
        """
        return torch.nn.functional.conv2d(
            x,
            weight,
            bias,
            stride=self.conv2d.stride,
            padding=self.conv2d.padding,
            dilation=self.conv2d.dilation,
            groups=self.conv2d.groups,
        )

    def qdq_forward(self, x):
        """Reference path: fake-quantize the smoothed input and the weight, then run a float conv."""
        scaled_x = x * self.mul_factor
        quant_weight = quantize_fp8(self.conv2d.weight, self.weight_scale)
        quant_input = quantize_fp8(scaled_x, self.input_scale)
        dequantized_weight = dequantize_fp8(quant_weight, self.weight_scale)
        dequantized_input = dequantize_fp8(quant_input, self.input_scale)
        return self._conv(dequantized_input, dequantized_weight, self.conv2d.bias)

    def qop_forward(self, x):
        """Quantized-operator path: conv on FP8 values, then rescale per output channel and add bias."""
        quant_weight = quantize_fp8(self.conv2d.weight, self.weight_scale).to(torch.float8_e4m3fnuz)

        # Fold the SmoothQuant multiplier into the input scale:
        # x / (s / m) == (x * m) / s (up to float rounding), so x * m is never formed.
        fused_input_scale = self.input_scale / self.mul_factor
        quant_input = quantize_fp8(x, fused_input_scale).to(torch.float8_e4m3fnuz)

        # F.conv2d has no FP8 kernel here; upcast the FP8 values to float32.
        quant_output = self._conv(quant_input.to(torch.float32), quant_weight.to(torch.float32), None)

        # Per-channel tensors are laid out on dim 1 and broadcast over the trailing spatial dims.
        # TODO(review): assumes a per-tensor input_scale.
        spatial_ones = [1] * (quant_output.ndim - 2)
        output_scale = self.weight_scale * self.input_scale
        output = dequantize_fp8(
            quant_output, output_scale.reshape([1, output_scale.nelement()] + spatial_ones)
        )

        bias = self.conv2d.bias
        if bias is not None:
            output = output + bias.reshape([1, bias.nelement()] + spatial_ones)
        return output

    def forward(self, x, qop=False):
        """Dispatch to ``qop_forward`` when ``qop`` is true, else ``qdq_forward``."""
        if qop:
            return self.qop_forward(x)
        return self.qdq_forward(x)
|
|
|
|
|
if __name__ == "__main__":
    # Self-check: for each layer, the fake-quant (qdq) and quantized-operator
    # (qop) forward paths should agree. Guarded so importing this module to
    # reuse the layers does not run it.
    torch.manual_seed(0)

    batch_size = 1
    seq_len = 11
    hidden_size = 21
    output_size = 36
    shape = 5  # spatial size of conv_input; also passed as the conv kernel_size below

    query = 2.*torch.rand((batch_size, seq_len, hidden_size)) - 1.
    conv_input = 2.*torch.rand((batch_size, hidden_size, shape, shape)) - 1.

    # 240. is the float8_e4m3fnuz saturation bound used by quantize_fp8.
    quant_params = {
        "quant_linear": {
            "smoothquant_mul": torch.randn(hidden_size).abs(),
            "smoothquant_mul_shape": [1, 1, hidden_size],
            "input_scale": torch.max(torch.abs(query)) / 240.,
            "input_scale_shape": [],
            "input_zp": 0.0,
            "input_zp_shape": [],
            "input_zp_dtype": "torch.float8_e4m3fnuz",
            "weight_scale": torch.randn(output_size).abs(),
            "weight_scale_shape": [output_size, 1],
        },
        "quant_conv": {
            "smoothquant_mul": torch.randn(hidden_size).abs(),
            "smoothquant_mul_shape": [1, hidden_size, 1, 1],
            # Calibrate from the tensor this layer actually consumes
            # (was copy-pasted from the linear config and used `query`).
            "input_scale": torch.max(torch.abs(conv_input)) / 240.,
            "input_scale_shape": [],
            "input_zp": 0.0,
            "input_zp_shape": [],
            "input_zp_dtype": "torch.float8_e4m3fnuz",
            "weight_scale": torch.randn(output_size).abs(),
            "weight_scale_shape": [output_size, 1, 1, 1],
        },
    }

    qlinear = QuantLinear(hidden_size, output_size, quant_params['quant_linear'])
    o = qlinear(query)
    q_o = qlinear(query, qop=True)
    assert torch.allclose(o, q_o)

    qconv = QuantConv2d(hidden_size, output_size, shape, quant_params['quant_conv'])
    o = qconv(conv_input)
    q_o = qconv(conv_input, qop=True)
    assert torch.allclose(o, q_o, atol=1e-6)
|
|