"""Smoke test comparing QuantScaledDotProductAttention's two execution paths.

Runs the module once in its default mode and once with ``qop=True`` on the
same seeded random inputs, then prints the quantization config, both output
shapes, and their element-wise difference.
"""
import torch
import torch.nn as nn

from attn import QuantScaledDotProductAttention

# Largest finite value of torch.float8_e4m3fnuz (torch.finfo(...).max == 240).
# Dividing a tensor's max magnitude by this gives a scale that maps that
# magnitude onto the top of the FP8 range.
FP8_E4M3FNUZ_MAX = 240.


def _fp8_act_quant(scale):
    """Return an activation quant-config dict for float8_e4m3fnuz.

    ``scale`` is stored as ``act_scale``. Both shape entries are empty lists
    and the zero point is 0.0, which is what every entry in this script uses.
    """
    return {
        "act_scale": scale,
        "act_scale_shape": [],
        "act_zp": 0.0,
        "act_zp_shape": [],
        "act_zp_dtype": "torch.float8_e4m3fnuz",
    }


torch.manual_seed(0)

batch_size = 1
seq_len = 11
hidden_size = 21

# Inputs drawn uniformly from [-1, 1).
query = 2. * torch.rand((batch_size, seq_len, hidden_size)) - 1.
key = 2. * torch.rand((batch_size, seq_len, hidden_size)) - 1.
value = 2. * torch.rand((batch_size, seq_len, hidden_size)) - 1.

quant_params = {
    # NOTE(review): this scale is random and has shape (1,), while its
    # "act_scale_shape" is [] and the q/k/v scales below are 0-d tensors —
    # confirm this is intended. The draw must stay after query/key/value so
    # the seeded values do not change.
    "output_softmax_quant": _fp8_act_quant(torch.rand((1,)) / FP8_E4M3FNUZ_MAX),
    "out_q": _fp8_act_quant(torch.max(torch.abs(query)) / FP8_E4M3FNUZ_MAX),
    "out_k": _fp8_act_quant(torch.max(torch.abs(key)) / FP8_E4M3FNUZ_MAX),
    "out_v": _fp8_act_quant(torch.max(torch.abs(value)) / FP8_E4M3FNUZ_MAX),
}
print(quant_params)

qsdpa = QuantScaledDotProductAttention(quant_params)
o_qdq = qsdpa(query, key, value)
o_qop = qsdpa(query, key, value, qop=True)

print(o_qdq.shape)
print(o_qop.shape)
diff = o_qdq - o_qop
print(diff)
# One-number summary of the mismatch; the full tensor above is hard to scan.
print(f"max abs diff: {diff.abs().max().item()}")