Upload folder using huggingface_hub
- debug_sdpa2.py +62 -0
- key_states_sdpa_3.pt +3 -0
- query_states_sdpa_3.pt +3 -0
- value_states_sdpa_3.pt +3 -0
debug_sdpa2.py
ADDED
@@ -0,0 +1,62 @@
+import torch
+import math
+
+query = torch.load("query_states_sdpa_3.pt")
+key = torch.load("key_states_sdpa_3.pt")
+value = torch.load("value_states_sdpa_3.pt")
+
+print("query", query.device, query.dtype, query.is_contiguous(), query.shape)
+print("key", key.device, key.dtype, key.is_contiguous(), key.shape)
+print("value", value.device, value.dtype, value.is_contiguous(), value.shape)
+
+torch.set_printoptions(threshold=1000000, precision=6)
+
+def scaled_dot_product_attention(query, key, value, is_causal: bool, custom_cast: bool):
+    scale_factor = math.sqrt(1 / math.sqrt(query.size(-1)))  # split 1/sqrt(head_dim) evenly between query and key
+    assert not is_causal
+
+    softmax_inp = (query * scale_factor) @ ((key * scale_factor).transpose(-2, -1))
+
+    if custom_cast:  # run the softmax in fp32, then cast back to the input dtype
+        attn_weight = torch.softmax(
+            softmax_inp,
+            dim=-1,
+            dtype=torch.float32
+        ).to(query.dtype)
+    else:
+        attn_weight = torch.softmax(softmax_inp, dim=-1)
+
+    return attn_weight @ value, softmax_inp
+
+is_causal = False
+
+with torch.no_grad():
+    with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):  # force the non-fused math backend
+        res_sdpa = torch.nn.functional.scaled_dot_product_attention(query, key, value, is_causal=is_causal, attn_mask=None)
+
+    res_eager_cast, softmax_inp = scaled_dot_product_attention(query, key, value, is_causal=is_causal, custom_cast=True)
+
+    res_eager_no_cast, _ = scaled_dot_product_attention(query, key, value, is_causal=is_causal, custom_cast=False)
+
+    res_softmax_0 = torch.softmax(softmax_inp, dim=-1)
+    res_softmax_1 = torch.softmax(softmax_inp, dim=-1, dtype=torch.float32).to(torch.float16)
+
+    print("-----")
+
+    absdiff = (res_softmax_0 - res_softmax_1).abs()
+    print("max absdiff softmax", absdiff.max())
+    print("median absdiff softmax", absdiff.median())
+
+    print("-----")
+
+    # These casts do not seem to matter.
+    res_sdpa = res_sdpa.to(torch.float32)
+    res_eager_cast = res_eager_cast.to(torch.float32)
+    res_eager_no_cast = res_eager_no_cast.to(torch.float32)
+
+    absdiff_nocast = (res_sdpa - res_eager_no_cast).abs()
+    absdiff_cast = (res_sdpa - res_eager_cast).abs()
+    print("SDPA max absdiff (no cast):", absdiff_nocast.max())
+    print("SDPA max absdiff (with cast):", absdiff_cast.max())
+    print("argwhere absdiff no cast > 0.0001", torch.argwhere(absdiff_nocast > 1e-4))
+    print("argwhere absdiff with cast > 0.0001", torch.argwhere(absdiff_cast > 1e-4))
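The script above needs the dumped tensors from this commit. For readers without them, a minimal standalone sketch of the same comparison on random data follows; the shape, seed, and dtype fallback are assumptions for illustration only, not the actual dump contents. (On recent PyTorch, the deprecated torch.backends.cuda.sdp_kernel context is superseded by torch.nn.attention.sdpa_kernel.)

# Standalone sketch: compare F.scaled_dot_product_attention against an eager
# implementation whose softmax runs in fp32, as debug_sdpa2.py does.
# Shape, seed, and dtype below are assumptions, not the dumped tensors.
import math
import torch

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32  # fp16 matmul may be unsupported on CPU

batch, heads, seq, head_dim = 1, 8, 128, 64  # hypothetical shape
query = torch.randn(batch, heads, seq, head_dim, device=device, dtype=dtype)
key = torch.randn_like(query)
value = torch.randn_like(query)

# Eager attention with the softmax accumulated in fp32, then cast back.
scale = 1 / math.sqrt(head_dim)
logits = (query @ key.transpose(-2, -1)) * scale
eager = torch.softmax(logits, dim=-1, dtype=torch.float32).to(dtype) @ value

sdpa = torch.nn.functional.scaled_dot_product_attention(query, key, value)
print("max absdiff:", (sdpa.float() - eager.float()).abs().max())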
key_states_sdpa_3.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b3bae5605ebe7843d99b5ca477993c6584544f1825663c26569e3a1b22c1341d
+size 10142926
query_states_sdpa_3.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5dad381f5ce6bbb3a1ec26b8d578d92395de8d57b11860a9c1726f9caa15887a
+size 10142936
value_states_sdpa_3.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:149317f8c4c45572bde47f6aeb704219cd9eccf7ca311d7ec66fb8e08de63219
+size 30426328
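The three .pt files are Git LFS pointers, so the repository itself stores only the sha256 oid and byte size shown above; the tensors arrive with `git lfs pull`. A hedged sketch for checking a downloaded object against its pointer (the local path is assumed; the oid and size are copied from the value pointer above):

# Verify a downloaded LFS object against its pointer metadata.
# Path is hypothetical; oid and size come from value_states_sdpa_3.pt's pointer.
import hashlib
import os

path = "value_states_sdpa_3.pt"
expected_oid = "149317f8c4c45572bde47f6aeb704219cd9eccf7ca311d7ec66fb8e08de63219"
expected_size = 30426328

assert os.path.getsize(path) == expected_size, "size mismatch"
h = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
        h.update(chunk)
assert h.hexdigest() == expected_oid, "sha256 mismatch"
print("LFS object verified")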