|
|
|
import math |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
def quantize_fp8(tensor: torch.Tensor, scale: torch.Tensor): |
|
dtype = tensor.dtype |
|
clamp_min, clamp_max = torch.tensor(-240., dtype=dtype), torch.tensor(240., dtype=dtype) |
|
quant_tensor = torch.clamp((tensor/scale), clamp_min, clamp_max).to(torch.float8_e4m3fnuz).to(dtype) |
|
return quant_tensor |
|
|
|
def dequantize_fp8(tensor: torch.Tensor, scale: torch.Tensor):
    """Undo the scaling applied by ``quantize_fp8`` by multiplying ``tensor`` by ``scale``.

    Args:
        tensor: quantized values, as returned by ``quantize_fp8``.
        scale: the same scale used for quantization, broadcastable against ``tensor``.

    Returns:
        ``tensor * scale``.
    """
    restored = tensor * scale
    return restored
|
|
|
|
|
def qdq_scaled_dot_product_attention(query, key, value, query_scale, key_scale, value_scale, softmax_scale, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
    """Scaled dot-product attention with FP8 quantize-dequantize (QDQ) simulation.

    Four tensors are round-tripped through ``quantize_fp8``/``dequantize_fp8``, each
    with its own scale: ``query``, ``key``, ``value`` and the softmax probabilities.
    The matmuls therefore run in the inputs' floating dtype on FP8-rounded values.

    Args:
        query, key, value: attention inputs. ``size(-2)`` is used as the sequence
            length and ``query.size(-1)`` as the feature size.
        query_scale, key_scale, value_scale: scales for the Q/K/V QDQ round-trip.
        softmax_scale: scale for the QDQ round-trip of the softmax output.
        attn_mask: optional mask broadcastable to ``(..., L, S)``. A boolean mask keeps
            positions where it is True; any other dtype is added as a bias.
        dropout_p: dropout probability on the attention weights. Dropout is always
            run with ``train=True``.
        is_causal: apply a lower-triangular causal mask. Must not be combined with
            ``attn_mask``.
        scale: logit scale; defaults to ``1 / sqrt(query.size(-1))``.

    Returns:
        ``attn_weight @ value``.
    """
    query = dequantize_fp8(quantize_fp8(query, query_scale), query_scale)
    key = dequantize_fp8(quantize_fp8(key, key_scale), key_scale)
    value = dequantize_fp8(quantize_fp8(value, value_scale), value_scale)

    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale

    # Create the bias on the inputs' device so non-CPU tensors work.
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        causal = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias = attn_bias.masked_fill(causal.logical_not(), float("-inf"))

    if attn_mask is not None:
        # Out-of-place so masks with leading batch/head dims broadcast against the
        # (L, S) bias instead of failing an in-place op.
        if attn_mask.dtype == torch.bool:
            attn_bias = attn_bias.masked_fill(attn_mask.logical_not(), float("-inf"))
        else:
            # Cast first to keep the bias in query's dtype, as the in-place add did.
            attn_bias = attn_bias + attn_mask.to(attn_bias.dtype)

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight = attn_weight + attn_bias
    attn_weight = dequantize_fp8(quantize_fp8(torch.softmax(attn_weight, dim=-1), softmax_scale), softmax_scale)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value
|
|
|
def qop_scaled_dot_product_attention(query, key, value, query_scale, key_scale, value_scale, softmax_scale, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
    """Scaled dot-product attention computed on FP8-quantized operands.

    Unlike the QDQ variant, ``query``, ``key``, ``value`` and the softmax output are
    only quantized (divided by their scale and FP8-rounded), not dequantized. The
    scales are folded back in afterwards:

    - the Q K^T logits are multiplied by ``query_scale * key_scale``;
    - the output is multiplied by ``softmax_scale * value_scale``.

    Args:
        query, key, value: attention inputs. ``size(-2)`` is used as the sequence
            length and ``query.size(-1)`` as the feature size.
        query_scale, key_scale, value_scale: quantization scales for Q/K/V.
        softmax_scale: quantization scale for the softmax output.
        attn_mask: optional mask broadcastable to ``(..., L, S)``. A boolean mask keeps
            positions where it is True; any other dtype is added as a bias.
        dropout_p: dropout probability on the attention weights. Dropout is always
            run with ``train=True``.
        is_causal: apply a lower-triangular causal mask. Must not be combined with
            ``attn_mask``.
        scale: logit scale; defaults to ``1 / sqrt(query.size(-1))``.

    Returns:
        ``(attn_weight @ value) * (softmax_scale * value_scale)``.
    """
    # Quantize only; the scales are folded back in after the matmuls below.
    query = quantize_fp8(query, query_scale)
    key = quantize_fp8(key, key_scale)
    value = quantize_fp8(value, value_scale)

    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    # quantize_fp8 divided Q and K by their scales, so multiply the product back
    # (up to FP8 rounding). Out-of-place: a caller-supplied tensor `scale` must not be
    # mutated, and the product may need to broadcast to a larger shape.
    scale_factor = scale_factor * (query_scale * key_scale)

    # Create the bias on the inputs' device so non-CPU tensors work.
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        causal = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias = attn_bias.masked_fill(causal.logical_not(), float("-inf"))

    if attn_mask is not None:
        # Out-of-place so masks with leading batch/head dims broadcast against the
        # (L, S) bias instead of failing an in-place op.
        if attn_mask.dtype == torch.bool:
            attn_bias = attn_bias.masked_fill(attn_mask.logical_not(), float("-inf"))
        else:
            # Cast first to keep the bias in query's dtype, as the in-place add did.
            attn_bias = attn_bias + attn_mask.to(attn_bias.dtype)

    attn_weight = (query @ key.transpose(-2, -1)) * scale_factor
    attn_weight = attn_weight + attn_bias
    attn_weight = quantize_fp8(torch.softmax(attn_weight, dim=-1), softmax_scale)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)

    # Reapply the softmax and value quantization scales to the output.
    return (attn_weight @ value) * (softmax_scale * value_scale)
|
|
|
|
|
class QuantScaledDotProductAttention(nn.Module):
    """Attention module holding FP8 activation scales for Q, K, V and the softmax output.

    The scales are read from ``quant_param`` and stored as buffers. ``forward``
    dispatches to the QDQ (quantize-dequantize) or QOp (quantized-operand)
    attention function.
    """

    # (buffer name, quant_param key, label used in the validation message)
    _SCALE_SPECS = (
        ('q_scale', 'out_q', 'Q'),
        ('k_scale', 'out_k', 'K'),
        ('v_scale', 'out_v', 'V'),
        ('sm_scale', 'output_softmax_quant', 'SoftMax'),
    )

    def __init__(self, quant_param):
        """Load and validate the scales from ``quant_param``.

        Each entry must provide ``act_scale``, ``act_scale_shape`` (used with
        ``Tensor.view``) and ``act_zp_dtype``. ``act_zp_dtype`` must equal the string
        ``'torch.float8_e4m3fnuz'``.
        """
        super().__init__()

        # Build every scale tensor first, then validate, then register buffers.
        loaded = {}
        for buffer_name, param_key, _ in self._SCALE_SPECS:
            entry = quant_param[param_key]
            loaded[buffer_name] = torch.tensor(entry['act_scale']).view(entry['act_scale_shape'])

        for _, param_key, label in self._SCALE_SPECS:
            zp_dtype = quant_param[param_key]['act_zp_dtype']
            assert zp_dtype == 'torch.float8_e4m3fnuz', f"{label} Zero-Point dtype should be 'torch.float8_e4m3fnuz', found: {zp_dtype}"

        for buffer_name, scale_tensor in loaded.items():
            self.register_buffer(buffer_name, scale_tensor)

    def qdq_forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
        """Run quantize-dequantize attention with the stored scales."""
        return qdq_scaled_dot_product_attention(
            query, key, value,
            self.q_scale, self.k_scale, self.v_scale, self.sm_scale,
            attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale,
        )

    def qop_forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
        """Run quantized-operand attention with the stored scales."""
        return qop_scaled_dot_product_attention(
            query, key, value,
            self.q_scale, self.k_scale, self.v_scale, self.sm_scale,
            attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale,
        )

    def forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, qop=False):
        """Dispatch to ``qop_forward`` when ``qop`` is truthy, otherwise to ``qdq_forward``."""
        runner = self.qop_forward if qop else self.qdq_forward
        return runner(query, key, value, attn_mask, dropout_p, is_causal, scale)
|
|