fxmarty (HF staff) committed
Commit fdbb530
1 parent: 9ebd7de

Upload folder using huggingface_hub

debug_sdpa2.py ADDED
@@ -0,0 +1,62 @@
+ import torch
+ import math
+
+ query = torch.load("query_states_sdpa_3.pt")
+ key = torch.load("key_states_sdpa_3.pt")
+ value = torch.load("value_states_sdpa_3.pt")
+
+ print("query", query.device, query.dtype, query.is_contiguous(), query.shape)
+ print("key", key.device, key.dtype, key.is_contiguous(), key.shape)
+ print("value", value.device, value.dtype, value.is_contiguous(), value.shape)
+
+ torch.set_printoptions(threshold=1000000, precision=6)
+
+ def scaled_dot_product_attention(query, key, value, is_causal: bool, custom_cast: bool):
+     scale_factor = math.sqrt(1 / math.sqrt(query.size(-1)))
+     assert not is_causal
+
+     softmax_inp = ((query * scale_factor) @ ((key * scale_factor).transpose(-2, -1)))
+
+     if custom_cast:
+         attn_weight = torch.softmax(
+             softmax_inp,
+             dim=-1,
+             dtype=torch.float32
+         ).to(query.dtype)
+     else:
+         attn_weight = torch.softmax(softmax_inp, dim=-1)
+
+     return attn_weight @ value, softmax_inp
+
+ is_causal = False
+
+ with torch.no_grad():
+     with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
+         res_sdpa = torch.nn.functional.scaled_dot_product_attention(query, key, value, is_causal=is_causal, attn_mask=None)
+
+     res_eager_cast, softmax_inp = scaled_dot_product_attention(query, key, value, is_causal=is_causal, custom_cast=True)
+
+     res_eager_no_cast, _ = scaled_dot_product_attention(query, key, value, is_causal=is_causal, custom_cast=False)
+
+     res_softmax_0 = torch.softmax(softmax_inp, dim=-1)
+     res_softmax_1 = torch.softmax(softmax_inp, dim=-1, dtype=torch.float32).to(torch.float16)
+
+     print("-----")
+
+     absdiff = (res_softmax_0 - res_softmax_1).abs()
+     print("max absdiff softmax", absdiff.max())
+     print("median absdiff softmax", absdiff.median())
+
+     print("-----")
+
+     # These casts do not seem to matter.
+     res_sdpa = res_sdpa.to(torch.float32)
+     res_eager_cast = res_eager_cast.to(torch.float32)
+     res_eager_no_cast = res_eager_no_cast.to(torch.float32)
+
+     absdiff_nocast = (res_sdpa - res_eager_no_cast).abs()
+     absdiff_cast = (res_sdpa - res_eager_cast).abs()
+     print("SDPA max absdiff (no cast):", absdiff_nocast.max())
+     print("SDPA max absdiff (with cast):", absdiff_cast.max())
+     print("argwhere absdiff no cast > 0.0001", torch.argwhere(absdiff_nocast > 1e-4))
+     print("argwhere absdiff with cast > 0.0001", torch.argwhere(absdiff_cast > 1e-4))
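
Note on the eager implementation above: it splits the attention scale across query and key, multiplying each by sqrt(1/sqrt(d)), which is algebraically the same as dividing the logits query @ key^T by sqrt(d). The following standalone sketch (not part of the commit) checks that equivalence on random float32 tensors; the shapes and data are illustrative assumptions, not the committed query/key tensors.

# Standalone sketch: split scale factor vs. standard 1/sqrt(d) scaling.
import math
import torch

torch.manual_seed(0)
q = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim), hypothetical shape
k = torch.randn(1, 8, 16, 64)

d = q.size(-1)
s = math.sqrt(1 / math.sqrt(d))                          # split scale, as in debug_sdpa2.py
logits_split = (q * s) @ (k * s).transpose(-2, -1)        # scale applied to q and k separately
logits_direct = (q @ k.transpose(-2, -1)) / math.sqrt(d)  # standard 1/sqrt(d) scaling

print("max abs diff:", (logits_split - logits_direct).abs().max().item())

In float32 the two forms agree up to rounding error; the point of debug_sdpa2.py is rather whether accumulating the softmax in float32 (before casting back to the input dtype) changes the result relative to the SDPA math backend.
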
key_states_sdpa_3.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b3bae5605ebe7843d99b5ca477993c6584544f1825663c26569e3a1b22c1341d
+ size 10142926
query_states_sdpa_3.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5dad381f5ce6bbb3a1ec26b8d578d92395de8d57b11860a9c1726f9caa15887a
+ size 10142936
value_states_sdpa_3.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:149317f8c4c45572bde47f6aeb704219cd9eccf7ca311d7ec66fb8e08de63219
+ size 30426328