nickfraser committed
Commit 3fea540
Parent(s): 740d40f
Added SDPA math model & test
Files changed:
- attn.py +101 -0
- test_attn.py +54 -0
attn.py
ADDED
@@ -0,0 +1,101 @@
import math

import torch
import torch.nn as nn


# Simulated FP8 (e4m3fnuz) quantization: divide by the scale, clamp to the
# representable range [-240, 240], and round by casting to float8 and back.
def quantize_fp8(tensor: torch.Tensor, scale: torch.Tensor):
    dtype = tensor.dtype
    clamp_min, clamp_max = torch.tensor(-240., dtype=dtype), torch.tensor(240., dtype=dtype)
    quant_tensor = torch.clamp((tensor/scale), clamp_min, clamp_max).to(torch.float8_e4m3fnuz).to(dtype)
    return quant_tensor

def dequantize_fp8(tensor: torch.Tensor, scale: torch.Tensor):
    return tensor * scale

# Quantize-dequantize (QDQ, "fake quantization") model of SDPA: Q, K, V and the
# softmax output are quantized and immediately dequantized, so all arithmetic
# runs in the original dtype.
# Based on: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
def qdq_scaled_dot_product_attention(query, key, value, query_scale, key_scale, value_scale, softmax_scale, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
    query = dequantize_fp8(quantize_fp8(query, query_scale), query_scale)
    key = dequantize_fp8(quantize_fp8(key, key_scale), key_scale)
    value = dequantize_fp8(quantize_fp8(value, value_scale), value_scale)
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias.to(query.dtype)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = dequantize_fp8(quantize_fp8(torch.softmax(attn_weight, dim=-1), softmax_scale), softmax_scale)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value

# Quantized-op (QOp) model of SDPA: operands stay in their quantized
# representation and the per-tensor scales are folded into the output scaling
# of each matmul.
def qop_scaled_dot_product_attention(query, key, value, query_scale, key_scale, value_scale, softmax_scale, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
    query = quantize_fp8(query, query_scale)
    key = quantize_fp8(key, key_scale)
    value = quantize_fp8(value, value_scale)
    # Your quantized kernel starts here
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    scale_factor *= (query_scale * key_scale)
    attn_bias = torch.zeros(L, S, dtype=query.dtype)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias.to(query.dtype)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask
    attn_weight = (query @ key.transpose(-2, -1)) * scale_factor # or, attn_weight = dequantize_fp8(query @ key.transpose(-2, -1), scale_factor)
    attn_weight += attn_bias
    attn_weight = quantize_fp8(torch.softmax(attn_weight, dim=-1), softmax_scale)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return (attn_weight @ value) * (softmax_scale * value_scale) # or, return dequantize_fp8(attn_weight @ value, softmax_scale * value_scale)

# Module that implements `torch.nn.functional.scaled_dot_product_attention`
class QuantScaledDotProductAttention(nn.Module):
    def __init__(self, quant_param):
        super().__init__()
        q_scale = torch.tensor(quant_param['out_q']['act_scale']).view(quant_param['out_q']['act_scale_shape'])
        k_scale = torch.tensor(quant_param['out_k']['act_scale']).view(quant_param['out_k']['act_scale_shape'])
        v_scale = torch.tensor(quant_param['out_v']['act_scale']).view(quant_param['out_v']['act_scale_shape'])
        sm_scale = torch.tensor(quant_param['output_softmax_quant']['act_scale']).view(quant_param['output_softmax_quant']['act_scale_shape'])
        # Not used, but included in the model. Kept because we use the zp_dtype as a type hint.
        #q_zp = torch.tensor(quant_param['out_q']['act_zp']).view(quant_param['out_q']['act_zp_shape'])
        #k_zp = torch.tensor(quant_param['out_k']['act_zp']).view(quant_param['out_k']['act_zp_shape'])
        #v_zp = torch.tensor(quant_param['out_v']['act_zp']).view(quant_param['out_v']['act_zp_shape'])
        assert quant_param['out_q']['act_zp_dtype'] == 'torch.float8_e4m3fnuz', f"Q Zero-Point dtype should be 'torch.float8_e4m3fnuz', found: {quant_param['out_q']['act_zp_dtype']}"
        assert quant_param['out_k']['act_zp_dtype'] == 'torch.float8_e4m3fnuz', f"K Zero-Point dtype should be 'torch.float8_e4m3fnuz', found: {quant_param['out_k']['act_zp_dtype']}"
        assert quant_param['out_v']['act_zp_dtype'] == 'torch.float8_e4m3fnuz', f"V Zero-Point dtype should be 'torch.float8_e4m3fnuz', found: {quant_param['out_v']['act_zp_dtype']}"
        assert quant_param['output_softmax_quant']['act_zp_dtype'] == 'torch.float8_e4m3fnuz', f"SoftMax Zero-Point dtype should be 'torch.float8_e4m3fnuz', found: {quant_param['output_softmax_quant']['act_zp_dtype']}"
        self.register_buffer('q_scale', q_scale)
        self.register_buffer('k_scale', k_scale)
        self.register_buffer('v_scale', v_scale)
        self.register_buffer('sm_scale', sm_scale)

    # I.e., "fake quantization"
    def qdq_forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
        return qdq_scaled_dot_product_attention(query, key, value, self.q_scale, self.k_scale, self.v_scale, self.sm_scale, attn_mask, dropout_p, is_causal, scale)

    # Accelerated version
    def qop_forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
        return qop_scaled_dot_product_attention(query, key, value, self.q_scale, self.k_scale, self.v_scale, self.sm_scale, attn_mask, dropout_p, is_causal, scale)

    def forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, qop=False):
        if qop:
            return self.qop_forward(query, key, value, attn_mask, dropout_p, is_causal, scale)
        else:
            return self.qdq_forward(query, key, value, attn_mask, dropout_p, is_causal, scale)
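
The QOp path relies on per-tensor scales commuting with matrix multiplication: (s_q * Qq) @ (s_k * Kq)^T = (s_q * s_k) * (Qq @ Kq^T), so dequantization can be deferred until after each matmul instead of being applied to the operands. Below is a minimal sketch (not part of the commit) that checks this scale-folding identity using the quantize_fp8/dequantize_fp8 helpers from attn.py above; the shapes and the absmax/240 scale choice are illustrative assumptions borrowed from the test script.

import torch
from attn import quantize_fp8, dequantize_fp8

torch.manual_seed(0)
x = torch.randn(4, 8)
y = torch.randn(8, 4)
sx = x.abs().max() / 240.  # per-tensor scales, mirroring test_attn.py
sy = y.abs().max() / 240.

# QDQ ordering: dequantize each operand, then multiply.
out_qdq = dequantize_fp8(quantize_fp8(x, sx), sx) @ dequantize_fp8(quantize_fp8(y, sy), sy)
# QOp ordering: multiply the quantized operands, then apply the folded scale.
out_qop = (quantize_fp8(x, sx) @ quantize_fp8(y, sy)) * (sx * sy)

print((out_qdq - out_qop).abs().max())  # differs only by floating-point rounding/reassociation
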
test_attn.py
ADDED
@@ -0,0 +1,54 @@
import torch
import torch.nn as nn

from attn import QuantScaledDotProductAttention

torch.manual_seed(0)

batch_size = 1
seq_len = 11
hidden_size = 21

query = 2.*torch.rand((batch_size,seq_len,hidden_size)) - 1.
key = 2.*torch.rand((batch_size,seq_len,hidden_size)) - 1.
value = 2.*torch.rand((batch_size,seq_len,hidden_size)) - 1.

# Per-tensor FP8 quantization parameters: scales come from the absmax of each
# input; zero-points are zero and only their dtype is checked by the module.
quant_params = {
    "output_softmax_quant": {
        "act_scale": 1./240.,
        "act_scale_shape": [],
        "act_zp": 0.0,
        "act_zp_shape": [],
        "act_zp_dtype": "torch.float8_e4m3fnuz"
    },
    "out_q": {
        "act_scale": torch.max(torch.abs(query)) / 240.,
        "act_scale_shape": [],
        "act_zp": 0.0,
        "act_zp_shape": [],
        "act_zp_dtype": "torch.float8_e4m3fnuz"
    },
    "out_k": {
        "act_scale": torch.max(torch.abs(key)) / 240.,
        "act_scale_shape": [],
        "act_zp": 0.0,
        "act_zp_shape": [],
        "act_zp_dtype": "torch.float8_e4m3fnuz"
    },
    "out_v": {
        "act_scale": torch.max(torch.abs(value)) / 240.,
        "act_scale_shape": [],
        "act_zp": 0.0,
        "act_zp_shape": [],
        "act_zp_dtype": "torch.float8_e4m3fnuz"
    },
}

print(quant_params)

# Run both the fake-quantized (QDQ) and quantized-op (QOp) paths and compare.
qsdpa = QuantScaledDotProductAttention(quant_params)
o_qdq = qsdpa(query, key, value)
o_qop = qsdpa(query, key, value, qop=True)
print(o_qdq.shape)
print(o_qop.shape)
print(o_qdq - o_qop)
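
A possible follow-up check (not in this commit): with dropout_p=0.0 the QDQ and QOp outputs should agree up to floating-point rounding, plus occasional one-step flips in the FP8 rounding of the softmax output, so a loose numerical comparison can be appended to test_attn.py. The 1e-2 tolerance below is an assumed bound, not something stated in the commit.

# Hypothetical extra check for test_attn.py (assumed tolerance).
max_err = (o_qdq - o_qop).abs().max()
print(f"max |o_qdq - o_qop| = {max_err}")
assert torch.allclose(o_qdq, o_qop, atol=1e-2), "QDQ and QOp paths diverged more than expected"
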