# sdxl-quant-fp8 / attn.py
# Commit 3fea540 by nickfraser: "Added SDPA math model & test" (6.26 kB)
import math
import torch
import torch.nn as nn
def quantize_fp8(tensor: torch.Tensor, scale: torch.Tensor):
dtype = tensor.dtype
clamp_min, clamp_max = torch.tensor(-240., dtype=dtype), torch.tensor(240., dtype=dtype)
quant_tensor = torch.clamp((tensor/scale), clamp_min, clamp_max).to(torch.float8_e4m3fnuz).to(dtype)
return quant_tensor
def dequantize_fp8(tensor: torch.Tensor, scale: torch.Tensor):
    """Undo the scaling applied by quantization.

    Args:
        tensor: Values expressed in the scaled (quantized) range.
        scale: Scale to multiply by.

    Returns:
        ``tensor * scale``.
    """
    rescaled = tensor * scale
    return rescaled
# Based on: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
def qdq_scaled_dot_product_attention(query, key, value, query_scale, key_scale, value_scale, softmax_scale, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
    """Scaled dot-product attention with quantize/dequantize ("fake quant") FP8 steps.

    Query, key, value and the softmax output are each passed through
    ``quantize_fp8`` followed by ``dequantize_fp8`` with their own scale, so
    all matmuls run on values in the original range.

    Args:
        query: Query tensor; ``L = query.size(-2)``, head dim ``query.size(-1)``.
        key: Key tensor; ``S = key.size(-2)``.
        value: Value tensor, multiplied by the ``(..., L, S)`` attention weights.
        query_scale: FP8 scale for ``query``.
        key_scale: FP8 scale for ``key``.
        value_scale: FP8 scale for ``value``.
        softmax_scale: FP8 scale for the softmax output.
        attn_mask: Optional mask broadcastable to the attention weights.
            Boolean masks keep positions that are ``True``; any other dtype is
            added to the attention scores.
        dropout_p: Dropout probability applied to the attention weights
            (dropout is always applied in training mode).
        is_causal: If true, apply a lower-triangular causal mask; ``attn_mask``
            must then be ``None``.
        scale: Optional multiplier for the scores; defaults to
            ``1 / sqrt(query.size(-1))``.

    Returns:
        ``attn_weight @ value``.
    """
    query = dequantize_fp8(quantize_fp8(query, query_scale), query_scale)
    key = dequantize_fp8(quantize_fp8(key, key_scale), key_scale)
    value = dequantize_fp8(quantize_fp8(value, value_scale), value_scale)

    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale

    # Allocate the bias next to the query so non-CPU inputs do not hit a
    # device mismatch when the bias is added to the scores.
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))

    if attn_mask is not None:
        # Combine out of place: an in-place update of the (L, S) bias cannot
        # broadcast to a mask that carries leading batch/head dimensions.
        if attn_mask.dtype == torch.bool:
            mask_bias = torch.zeros(attn_mask.shape, dtype=query.dtype, device=query.device)
            mask_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            attn_bias = attn_bias + mask_bias
        else:
            # Cast first so the bias (and hence the output) keeps query's dtype.
            attn_bias = attn_bias + attn_mask.to(dtype=query.dtype)

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight = attn_weight + attn_bias
    attn_weight = dequantize_fp8(quantize_fp8(torch.softmax(attn_weight, dim=-1), softmax_scale), softmax_scale)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value
def qop_scaled_dot_product_attention(query, key, value, query_scale, key_scale, value_scale, softmax_scale, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
    """Scaled dot-product attention operating on quantized (not dequantized) operands.

    Query, key, value and the softmax output are quantized with
    ``quantize_fp8`` and left in the scaled range. The scales are folded in
    afterwards: ``query_scale * key_scale`` into the score scale factor, and
    ``softmax_scale * value_scale`` into the final output.

    Args:
        query: Query tensor; ``L = query.size(-2)``, head dim ``query.size(-1)``.
        key: Key tensor; ``S = key.size(-2)``.
        value: Value tensor, multiplied by the ``(..., L, S)`` attention weights.
        query_scale: FP8 scale for ``query``.
        key_scale: FP8 scale for ``key``.
        value_scale: FP8 scale for ``value``.
        softmax_scale: FP8 scale for the softmax output.
        attn_mask: Optional mask broadcastable to the attention weights.
            Boolean masks keep positions that are ``True``; any other dtype is
            added to the attention scores.
        dropout_p: Dropout probability applied to the attention weights
            (dropout is always applied in training mode).
        is_causal: If true, apply a lower-triangular causal mask; ``attn_mask``
            must then be ``None``.
        scale: Optional multiplier for the scores; defaults to
            ``1 / sqrt(query.size(-1))``. It is never modified in place.

    Returns:
        ``(attn_weight @ value) * (softmax_scale * value_scale)``.
    """
    query = quantize_fp8(query, query_scale)
    key = quantize_fp8(key, key_scale)
    value = quantize_fp8(value, value_scale)

    # Your quantized kernel starts here
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    # Out-of-place multiply: `*=` would overwrite the caller's `scale` tensor
    # when one is passed in.
    scale_factor = scale_factor * (query_scale * key_scale)

    # Allocate the bias next to the query so non-CPU inputs do not hit a
    # device mismatch when the bias is added to the scores.
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))

    if attn_mask is not None:
        # Combine out of place: an in-place update of the (L, S) bias cannot
        # broadcast to a mask that carries leading batch/head dimensions.
        if attn_mask.dtype == torch.bool:
            mask_bias = torch.zeros(attn_mask.shape, dtype=query.dtype, device=query.device)
            mask_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            attn_bias = attn_bias + mask_bias
        else:
            # Cast first so the bias (and hence the output) keeps query's dtype.
            attn_bias = attn_bias + attn_mask.to(dtype=query.dtype)

    attn_weight = (query @ key.transpose(-2, -1)) * scale_factor  # or, attn_weight = dequantize_fp8(query @ key.transpose(-2, -1), scale_factor)
    attn_weight = attn_weight + attn_bias
    attn_weight = quantize_fp8(torch.softmax(attn_weight, dim=-1), softmax_scale)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return (attn_weight @ value) * (softmax_scale * value_scale)  # or, return dequantize_fp8(attn_weight @ value, softmax_scale * value_scale)
# Module that implements `torch.nn.functional.scaled_dot_product_attention`
class QuantScaledDotProductAttention(nn.Module):
    """FP8-quantized stand-in for ``torch.nn.functional.scaled_dot_product_attention``.

    The activation scales for Q, K, V and the softmax output are read from
    ``quant_param`` and stored as buffers ``q_scale``, ``k_scale``, ``v_scale``
    and ``sm_scale``.
    """

    def __init__(self, quant_param):
        """Load the four activation scales from ``quant_param``.

        ``quant_param`` is indexed by ``'out_q'``, ``'out_k'``, ``'out_v'`` and
        ``'output_softmax_quant'``; each entry provides ``'act_scale'``,
        ``'act_scale_shape'`` and ``'act_zp_dtype'``.
        """
        super().__init__()
        # (buffer name, quant_param key, label used in the assertion message)
        layout = (
            ('q_scale', 'out_q', 'Q'),
            ('k_scale', 'out_k', 'K'),
            ('v_scale', 'out_v', 'V'),
            ('sm_scale', 'output_softmax_quant', 'SoftMax'),
        )

        loaded = {}
        for buffer_name, param_key, _ in layout:
            entry = quant_param[param_key]
            loaded[buffer_name] = torch.tensor(entry['act_scale']).view(entry['act_scale_shape'])

        # The zero-point tensors themselves are not loaded; only their dtype
        # string is checked, as a hint of the quantized type.
        for _, param_key, label in layout:
            zp_dtype = quant_param[param_key]['act_zp_dtype']
            assert zp_dtype == 'torch.float8_e4m3fnuz', f"{label} Zero-Point dtype should be 'torch.float8_e4m3fnuz', found: {zp_dtype}"

        for buffer_name, scale_tensor in loaded.items():
            self.register_buffer(buffer_name, scale_tensor)

    # I.e., "fake quantization"
    def qdq_forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
        """Run attention through the quantize/dequantize reference path."""
        return qdq_scaled_dot_product_attention(
            query,
            key,
            value,
            self.q_scale,
            self.k_scale,
            self.v_scale,
            self.sm_scale,
            attn_mask=attn_mask,
            dropout_p=dropout_p,
            is_causal=is_causal,
            scale=scale,
        )

    # Accelerated version
    def qop_forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
        """Run attention through the quantized-operand path."""
        return qop_scaled_dot_product_attention(
            query,
            key,
            value,
            self.q_scale,
            self.k_scale,
            self.v_scale,
            self.sm_scale,
            attn_mask=attn_mask,
            dropout_p=dropout_p,
            is_causal=is_causal,
            scale=scale,
        )

    def forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, qop=False):
        """Dispatch to ``qop_forward`` when ``qop`` is truthy, else ``qdq_forward``."""
        selected = self.qop_forward if qop else self.qdq_forward
        return selected(query, key, value, attn_mask, dropout_p, is_causal, scale)