import torch
import torch.nn as nn

from attn import QuantScaledDotProductAttention
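
# Fix the RNG seed so the random Q/K/V tensors below are reproducible.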
torch.manual_seed(0)

batch_size = 1
seq_len = 11
hidden_size = 21
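
# Random activations drawn uniformly from [-1, 1).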
query = 2. * torch.rand((batch_size, seq_len, hidden_size)) - 1.
key = 2. * torch.rand((batch_size, seq_len, hidden_size)) - 1.
value = 2. * torch.rand((batch_size, seq_len, hidden_size)) - 1.
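
# Per-tensor fp8 quantization parameters (a scale shape of [] means one scale per tensor).
# 240 is the largest finite value representable in torch.float8_e4m3fnuz, so each
# activation scale maps that tensor's absolute maximum onto the fp8 range.
# The softmax output lies in [0, 1], so its scale is fixed at 1/240.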
quant_params = {
    "output_softmax_quant": {
        "act_scale": 1. / 240.,
        "act_scale_shape": [],
        "act_zp": 0.0,
        "act_zp_shape": [],
        "act_zp_dtype": "torch.float8_e4m3fnuz",
    },
    "out_q": {
        "act_scale": torch.max(torch.abs(query)) / 240.,
        "act_scale_shape": [],
        "act_zp": 0.0,
        "act_zp_shape": [],
        "act_zp_dtype": "torch.float8_e4m3fnuz",
    },
    "out_k": {
        "act_scale": torch.max(torch.abs(key)) / 240.,
        "act_scale_shape": [],
        "act_zp": 0.0,
        "act_zp_shape": [],
        "act_zp_dtype": "torch.float8_e4m3fnuz",
    },
    "out_v": {
        "act_scale": torch.max(torch.abs(value)) / 240.,
        "act_scale_shape": [],
        "act_zp": 0.0,
        "act_zp_shape": [],
        "act_zp_dtype": "torch.float8_e4m3fnuz",
    },
}

print(quant_params)
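
# Run the same attention module along both quantization paths: the default
# quantize-dequantize (QDQ) simulation and the quantized-op path (qop=True),
# then print the elementwise difference between the two outputs.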
qsdpa = QuantScaledDotProductAttention(quant_params)
o_qdq = qsdpa(query, key, value)
o_qop = qsdpa(query, key, value, qop=True)
print(o_qdq.shape)
print(o_qop.shape)
print(o_qdq - o_qop)