import math

import torch

# Load the attention inputs that were dumped to disk for this comparison.
# NOTE(review): torch.load unpickles the file contents, so only run this on
# dumps from a trusted source (consider weights_only=True if the installed
# PyTorch supports it and the files hold plain tensors — confirm).
query = torch.load("query_states_sdpa_3.pt")
key = torch.load("key_states_sdpa_3.pt")
value = torch.load("value_states_sdpa_3.pt")

# Print the properties most likely to affect kernel selection and numerics.
print("query", query.device, query.dtype, query.is_contiguous(), query.shape)
print("key", key.device, key.dtype, key.is_contiguous(), key.shape)
print("value", value.device, value.dtype, value.is_contiguous(), value.shape)

# Raise the summarization threshold and precision so that printed tensors
# (e.g. the argwhere results below) are shown in full.
torch.set_printoptions(threshold=1000000, precision=6)


def scaled_dot_product_attention(query, key, value, is_causal: bool, custom_cast: bool):
    """Eager reference implementation of non-causal scaled dot-product attention.

    Both ``query`` and ``key`` are multiplied by ``d ** -0.25`` (with
    ``d = query.size(-1)``) before the matmul, so the resulting logits carry
    the usual overall ``1 / sqrt(d)`` scaling.

    Args:
        query: Query tensor; its last dimension defines the scale factor.
        key: Key tensor; its last two dimensions are transposed before the
            matmul with ``query``.
        value: Value tensor multiplied by the attention weights.
        is_causal: Must be ``False``; causal masking is not implemented.
        custom_cast: If ``True``, the softmax is computed in ``float32`` and
            the result is cast back to ``query.dtype``. If ``False``, the
            softmax runs in the dtype of the logits.

    Returns:
        A tuple ``(output, softmax_inp)`` where ``output`` is
        ``softmax(softmax_inp) @ value`` and ``softmax_inp`` holds the
        pre-softmax attention logits.

    Raises:
        NotImplementedError: If ``is_causal`` is ``True``.
    """
    # Explicit raise instead of ``assert``: asserts are stripped under
    # ``python -O``, which would silently return non-causal results.
    if is_causal:
        raise NotImplementedError("is_causal=True is not supported by this reference implementation")

    # sqrt(1 / sqrt(d)) == d ** -0.25; applied to both operands below.
    scale_factor = math.sqrt(1 / math.sqrt(query.size(-1)))

    softmax_inp = (query * scale_factor) @ (key * scale_factor).transpose(-2, -1)

    if custom_cast:
        # Upcast for the softmax only, then return to the input precision.
        attn_weight = torch.softmax(softmax_inp, dim=-1, dtype=torch.float32).to(query.dtype)
    else:
        attn_weight = torch.softmax(softmax_inp, dim=-1)

    return attn_weight @ value, softmax_inp


is_causal = False

with torch.no_grad():
    # Restrict SDPA to the math backend (flash and memory-efficient disabled).
    # NOTE(review): newer PyTorch releases reportedly deprecate this context
    # manager in favor of torch.nn.attention.sdpa_kernel — confirm against the
    # installed version before switching.
    with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
        res_sdpa = torch.nn.functional.scaled_dot_product_attention(
            query, key, value, is_causal=is_causal, attn_mask=None
        )

    # Eager reference, with and without the float32 softmax upcast.
    res_eager_cast, softmax_inp = scaled_dot_product_attention(
        query, key, value, is_causal=is_causal, custom_cast=True
    )
    res_eager_no_cast, _ = scaled_dot_product_attention(
        query, key, value, is_causal=is_causal, custom_cast=False
    )

    # Isolate the softmax: native-dtype softmax vs. float32 softmax cast back.
    # The cast targets the logits' own dtype (instead of a hard-coded
    # float16) so the comparison stays meaningful for any input precision,
    # matching the `.to(query.dtype)` cast used in the eager helper.
    res_softmax_0 = torch.softmax(softmax_inp, dim=-1)
    res_softmax_1 = torch.softmax(softmax_inp, dim=-1, dtype=torch.float32).to(softmax_inp.dtype)
    absdiff = (res_softmax_0 - res_softmax_1).abs()

    # Compare the final attention outputs in float32.
    res_sdpa = res_sdpa.to(torch.float32)
    res_eager_cast = res_eager_cast.to(torch.float32)
    res_eager_no_cast = res_eager_no_cast.to(torch.float32)

    absdiff_nocast = (res_sdpa - res_eager_no_cast).abs()
    absdiff_cast = (res_sdpa - res_eager_cast).abs()

print("-----")

print("max absdiff softmax", absdiff.max())
print("median absdiff softmax", absdiff.median())

print("-----")

print("SDPA max absdiff (no cast):", absdiff_nocast.max())
print("SDPA max absdiff (with cast):", absdiff_cast.max())
print("argwhere absdiff no cast > 0.0001", torch.argwhere(absdiff_nocast > 1e-4))
print("argwhere absdiff with cast > 0.0001", torch.argwhere(absdiff_cast > 1e-4))
|