import torch
import math

# Load the query/key/value states saved to disk.
query = torch.load("query_states_sdpa_3.pt")
key = torch.load("key_states_sdpa_3.pt")
value = torch.load("value_states_sdpa_3.pt")

print("query", query.device, query.dtype, query.is_contiguous(), query.shape)
print("key", key.device, key.dtype, key.is_contiguous(), key.shape)
print("value", value.device, value.dtype, value.is_contiguous(), value.shape)

torch.set_printoptions(threshold=1000000, precision=6)

def scaled_dot_product_attention(query, key, value, is_causal: bool, custom_cast: bool):
    # d^(-1/4): applied to both query and key so their product carries the usual
    # 1/sqrt(d) attention scaling (mirroring the split used by the SDPA math backend).
    scale_factor = math.sqrt(1 / math.sqrt(query.size(-1)))
    assert not is_causal  # causal masking is not reimplemented here

    softmax_inp = (query * scale_factor) @ (key * scale_factor).transpose(-2, -1)

    if custom_cast:
        # Softmax upcast to float32, then cast back to the input dtype.
        attn_weight = torch.softmax(
            softmax_inp,
            dim=-1,
            dtype=torch.float32
        ).to(query.dtype)
    else:
        # Softmax computed directly in the input dtype.
        attn_weight = torch.softmax(softmax_inp, dim=-1)

    return attn_weight @ value, softmax_inp

is_causal = False

with torch.no_grad():
    # Restrict SDPA to the math backend so its result is directly comparable
    # with the eager reimplementation above.
    with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
        res_sdpa = torch.nn.functional.scaled_dot_product_attention(query, key, value, is_causal=is_causal, attn_mask=None)

        res_eager_cast, softmax_inp = scaled_dot_product_attention(query, key, value, is_causal=is_causal, custom_cast=True)

        res_eager_no_cast, _ = scaled_dot_product_attention(query, key, value, is_causal=is_causal, custom_cast=False)

# Softmax in the tensors' own dtype vs. upcast to float32 and cast back to float16.
res_softmax_0 = torch.softmax(softmax_inp, dim=-1)
res_softmax_1 = torch.softmax(softmax_inp, dim=-1, dtype=torch.float32).to(torch.float16)

print("-----")

absdiff = (res_softmax_0 - res_softmax_1).abs()
print("max absdiff softmax", absdiff.max())
print("median absdiff softmax", absdiff.median())

print("-----")

# These casts do not seem to matter; they only upcast the results to float32 before comparing.
res_sdpa = res_sdpa.to(torch.float32)
res_eager_cast = res_eager_cast.to(torch.float32)
res_eager_no_cast = res_eager_no_cast.to(torch.float32)

absdiff_nocast = (res_sdpa - res_eager_no_cast).abs()
absdiff_cast = (res_sdpa - res_eager_cast).abs()
print("SDPA max absdiff (no cast):", absdiff_nocast.max())
print("SDPA max absdiff (with cast):", absdiff_cast.max())
print("argwhere absdiff no cast > 0.0001", torch.argwhere(absdiff_nocast > 1e-4))
print("argwhere absdiff with cast > 0.0001", torch.argwhere(absdiff_cast > 1e-4))