nickfraser committed
Commit: 3fea540
Parent: 740d40f

Added SDPA math model & test

Files changed (2)
  1. attn.py +101 -0
  2. test_attn.py +54 -0
attn.py ADDED
@@ -0,0 +1,101 @@
+
+import math
+
+import torch
+import torch.nn as nn
+
+def quantize_fp8(tensor: torch.Tensor, scale: torch.Tensor):
+    dtype = tensor.dtype
+    clamp_min, clamp_max = torch.tensor(-240., dtype=dtype), torch.tensor(240., dtype=dtype)
+    quant_tensor = torch.clamp((tensor/scale), clamp_min, clamp_max).to(torch.float8_e4m3fnuz).to(dtype)
+    return quant_tensor
+
+def dequantize_fp8(tensor: torch.Tensor, scale: torch.Tensor):
+    return tensor * scale
+
+# Based on: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
+def qdq_scaled_dot_product_attention(query, key, value, query_scale, key_scale, value_scale, softmax_scale, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
+    query = dequantize_fp8(quantize_fp8(query, query_scale), query_scale)
+    key = dequantize_fp8(quantize_fp8(key, key_scale), key_scale)
+    value = dequantize_fp8(quantize_fp8(value, value_scale), value_scale)
+    L, S = query.size(-2), key.size(-2)
+    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
+    attn_bias = torch.zeros(L, S, dtype=query.dtype)
+    if is_causal:
+        assert attn_mask is None
+        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
+        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+        attn_bias.to(query.dtype)
+
+    if attn_mask is not None:
+        if attn_mask.dtype == torch.bool:
+            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+        else:
+            attn_bias += attn_mask
+    attn_weight = query @ key.transpose(-2, -1) * scale_factor
+    attn_weight += attn_bias
+    attn_weight = dequantize_fp8(quantize_fp8(torch.softmax(attn_weight, dim=-1), softmax_scale), softmax_scale)
+    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+    return attn_weight @ value
+
+def qop_scaled_dot_product_attention(query, key, value, query_scale, key_scale, value_scale, softmax_scale, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
+    query = quantize_fp8(query, query_scale)
+    key = quantize_fp8(key, key_scale)
+    value = quantize_fp8(value, value_scale)
+    # Your quantized kernel starts here
+    L, S = query.size(-2), key.size(-2)
+    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
+    scale_factor *= (query_scale * key_scale)
+    attn_bias = torch.zeros(L, S, dtype=query.dtype)
+    if is_causal:
+        assert attn_mask is None
+        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
+        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+        attn_bias.to(query.dtype)
+
+    if attn_mask is not None:
+        if attn_mask.dtype == torch.bool:
+            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+        else:
+            attn_bias += attn_mask
+    attn_weight = (query @ key.transpose(-2, -1)) * scale_factor  # or, attn_weight = dequantize_fp8(query @ key.transpose(-2, -1), scale_factor)
+    attn_weight += attn_bias
+    attn_weight = quantize_fp8(torch.softmax(attn_weight, dim=-1), softmax_scale)
+    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+    return (attn_weight @ value) * (softmax_scale * value_scale)  # or, return dequantize_fp8(attn_weight @ value, softmax_scale * value_scale)
+
+# Module that implements `torch.nn.functional.scaled_dot_product_attention`
+class QuantScaledDotProductAttention(nn.Module):
+    def __init__(self, quant_param):
+        super().__init__()
+        q_scale = torch.tensor(quant_param['out_q']['act_scale']).view(quant_param['out_q']['act_scale_shape'])
+        k_scale = torch.tensor(quant_param['out_k']['act_scale']).view(quant_param['out_k']['act_scale_shape'])
+        v_scale = torch.tensor(quant_param['out_v']['act_scale']).view(quant_param['out_v']['act_scale_shape'])
+        sm_scale = torch.tensor(quant_param['output_softmax_quant']['act_scale']).view(quant_param['output_softmax_quant']['act_scale_shape'])
+        # Not used, included in model. Kept because we use the zp_dtype as a type hint
+        #q_zp = torch.tensor(quant_param['out_q']['act_zp']).view(quant_param['out_q']['act_zp_shape'])
+        #k_zp = torch.tensor(quant_param['out_k']['act_zp']).view(quant_param['out_k']['act_zp_shape'])
+        #v_zp = torch.tensor(quant_param['out_v']['act_zp']).view(quant_param['out_v']['act_zp_shape'])
+        assert quant_param['out_q']['act_zp_dtype'] == 'torch.float8_e4m3fnuz', f"Q Zero-Point dtype should be 'torch.float8_e4m3fnuz', found: {quant_param['out_q']['act_zp_dtype']}"
+        assert quant_param['out_k']['act_zp_dtype'] == 'torch.float8_e4m3fnuz', f"K Zero-Point dtype should be 'torch.float8_e4m3fnuz', found: {quant_param['out_k']['act_zp_dtype']}"
+        assert quant_param['out_v']['act_zp_dtype'] == 'torch.float8_e4m3fnuz', f"V Zero-Point dtype should be 'torch.float8_e4m3fnuz', found: {quant_param['out_v']['act_zp_dtype']}"
+        assert quant_param['output_softmax_quant']['act_zp_dtype'] == 'torch.float8_e4m3fnuz', f"SoftMax Zero-Point dtype should be 'torch.float8_e4m3fnuz', found: {quant_param['output_softmax_quant']['act_zp_dtype']}"
+        self.register_buffer('q_scale', q_scale)
+        self.register_buffer('k_scale', k_scale)
+        self.register_buffer('v_scale', v_scale)
+        self.register_buffer('sm_scale', sm_scale)
+
+
+    # I.e., "fake quantization"
+    def qdq_forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
+        return qdq_scaled_dot_product_attention(query, key, value, self.q_scale, self.k_scale, self.v_scale, self.sm_scale, attn_mask, dropout_p, is_causal, scale)
+
+    # Accelerated version
+    def qop_forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
+        return qop_scaled_dot_product_attention(query, key, value, self.q_scale, self.k_scale, self.v_scale, self.sm_scale, attn_mask, dropout_p, is_causal, scale)
+
+    def forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, qop=False):
+        if qop:
+            return self.qop_forward(query, key, value, attn_mask, dropout_p, is_causal, scale)
+        else:
+            return self.qdq_forward(query, key, value, attn_mask, dropout_p, is_causal, scale)
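The qop path relies on per-tensor scales factoring out of the matmuls: dequantize_fp8(q, qs) @ dequantize_fp8(k, ks).transpose(-2, -1) equals (qs * ks) * (quantize_fp8(q, qs) @ quantize_fp8(k, ks).transpose(-2, -1)), which is why qop_scaled_dot_product_attention folds query_scale * key_scale into scale_factor and rescales the final matmul by softmax_scale * value_scale. Below is a minimal sketch of that identity, assuming attn.py is importable from the working directory; the shapes and the allclose tolerance are illustrative choices, not values taken from the files in this commit.

    import torch
    from attn import quantize_fp8, dequantize_fp8

    torch.manual_seed(0)
    q = 2. * torch.rand((1, 11, 21)) - 1.
    k = 2. * torch.rand((1, 11, 21)) - 1.
    q_scale = torch.max(torch.abs(q)) / 240.
    k_scale = torch.max(torch.abs(k)) / 240.

    # QDQ form: dequantize each operand, then matmul in the original domain.
    qdq = dequantize_fp8(quantize_fp8(q, q_scale), q_scale) @ dequantize_fp8(quantize_fp8(k, k_scale), k_scale).transpose(-2, -1)
    # QOp form: matmul the quantized operands, then apply both scales once.
    qop = (quantize_fp8(q, q_scale) @ quantize_fp8(k, k_scale).transpose(-2, -1)) * (q_scale * k_scale)

    print(torch.allclose(qdq, qop, atol=1e-5))  # should print True, up to float rounding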
test_attn.py ADDED
@@ -0,0 +1,54 @@
+import torch
+import torch.nn as nn
+
+from attn import QuantScaledDotProductAttention
+
+torch.manual_seed(0)
+
+batch_size = 1
+seq_len = 11
+hidden_size = 21
+
+query = 2.*torch.rand((batch_size,seq_len,hidden_size)) - 1.
+key = 2.*torch.rand((batch_size,seq_len,hidden_size)) - 1.
+value = 2.*torch.rand((batch_size,seq_len,hidden_size)) - 1.
+
+quant_params = {
+    "output_softmax_quant": {
+        "act_scale": 1./240.,
+        "act_scale_shape": [],
+        "act_zp": 0.0,
+        "act_zp_shape": [],
+        "act_zp_dtype": "torch.float8_e4m3fnuz"
+    },
+    "out_q": {
+        "act_scale": torch.max(torch.abs(query)) / 240.,
+        "act_scale_shape": [],
+        "act_zp": 0.0,
+        "act_zp_shape": [],
+        "act_zp_dtype": "torch.float8_e4m3fnuz"
+    },
+    "out_k": {
+        "act_scale": torch.max(torch.abs(key)) / 240.,
+        "act_scale_shape": [],
+        "act_zp": 0.0,
+        "act_zp_shape": [],
+        "act_zp_dtype": "torch.float8_e4m3fnuz"
+    },
+    "out_v": {
+        "act_scale": torch.max(torch.abs(value)) / 240.,
+        "act_scale_shape": [],
+        "act_zp": 0.0,
+        "act_zp_shape": [],
+        "act_zp_dtype": "torch.float8_e4m3fnuz"
+    },
+}
+
+print(quant_params)
+
+qsdpa = QuantScaledDotProductAttention(quant_params)
+o_qdq = qsdpa(query, key, value)
+o_qop = qsdpa(query, key, value, qop=True)
+print(o_qdq.shape)
+print(o_qop.shape)
+print(o_qdq - o_qop)
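The script prints the raw elementwise difference between the QDQ and QOp outputs rather than asserting on it. A hedged follow-up check that could be appended to test_attn.py is sketched below; the atol is an arbitrary choice, not a tolerance fixed by the files in this commit.

    # Illustrative tightening of the final print; the tolerance is a guess.
    max_abs_err = (o_qdq - o_qop).abs().max().item()
    print(f"max |o_qdq - o_qop| = {max_abs_err:.3e}")
    print(torch.allclose(o_qdq, o_qop, atol=1e-4))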