k4d3
/

English
Not-For-All-Audiences
yiff_toolkit / scripts /sdp_benchmark.py
k4d3's picture
awoo
4776524
raw
history blame
No virus
1.81 kB
import torch
import torch.nn.functional as F
import torch.utils.benchmark as benchmark
from torch.backends.cuda import sdp_kernel, SDPBackend
# Run on the GPU when CUDA reports one is available; otherwise use the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def benchmark_torch_function_in_milliseconds(f, *args, **kwargs):
    """Time ``f(*args, **kwargs)`` and return its mean duration per call in ms.

    Timing is delegated to ``torch.utils.benchmark.Timer.blocked_autorange``,
    which keeps running the statement in blocks until a minimum total run
    time has been reached.

    Args:
        f: The callable to benchmark.
        *args: Positional arguments forwarded to ``f`` on every call.
        **kwargs: Keyword arguments forwarded to ``f`` on every call.

    Returns:
        The mean duration of a single call, in milliseconds.
    """
    # The Timer evaluates ``stmt`` as source code, so the callable and its
    # arguments are handed to it through the globals namespace.
    namespace = {"f": f, "kwargs": kwargs, "args": args}
    timer = benchmark.Timer(stmt="f(*args, **kwargs)", globals=namespace)
    measurement = timer.blocked_autorange()
    seconds_per_call = measurement.mean
    return seconds_per_call * 1e3
# Problem size. The q/k/v tensors below are built with shape
# (batch_size, num_heads, max_sequence_len, embed_dimension).
batch_size = 32
max_sequence_len = 1024
num_heads = 32
embed_dimension = 32
# float16 is the intended benchmark dtype, but half-precision matmul may not be
# implemented on CPU in older PyTorch releases (and is not representative there
# anyway), so fall back to float32 when running without a GPU instead of
# crashing on the first benchmark.
# NOTE(review): on CPU the math path materialises a
# batch_size * num_heads * max_sequence_len**2 score matrix (~4 GiB in float32
# at these sizes) -- consider shrinking the sizes for CPU runs.
dtype = torch.float16 if device == "cuda" else torch.float32
query = torch.rand(
    batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype
)
key = torch.rand(
    batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype
)
value = torch.rand(
    batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype
)

# Baseline: let PyTorch pick whichever SDPA backend it considers best.
default_ms = benchmark_torch_function_in_milliseconds(
    F.scaled_dot_product_attention, query, key, value
)
print(f"The default implementation runs in {default_ms:.3f} milliseconds")

# Keyword flags for sdp_kernel that leave exactly one backend enabled.
# NOTE(review): these flags live in torch.backends.cuda and presumably only
# steer CUDA kernel selection; on CPU the labels below may not reflect the
# kernel actually used -- confirm. Newer PyTorch releases also deprecate
# sdp_kernel in favour of torch.nn.attention.sdpa_kernel.
backend_map = {
    SDPBackend.MATH: {"enable_math": True, "enable_flash": False, "enable_mem_efficient": False},
    SDPBackend.EFFICIENT_ATTENTION: {
        "enable_math": False,
        "enable_flash": False,
        "enable_mem_efficient": True,
    },
}

with sdp_kernel(**backend_map[SDPBackend.MATH]):
    math_ms = benchmark_torch_function_in_milliseconds(
        F.scaled_dot_product_attention, query, key, value
    )
    print(f"The math implementation runs in {math_ms:.3f} milliseconds")

with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
    # With only the memory-efficient kernel enabled, SDPA raises RuntimeError
    # when that kernel cannot handle these inputs/hardware; the reasons are
    # reported as warnings.
    try:
        efficient_ms = benchmark_torch_function_in_milliseconds(
            F.scaled_dot_product_attention, query, key, value
        )
    except RuntimeError:
        print("EfficientAttention is not supported. See warnings for reasons.")
    else:
        print(f"The memory efficient implementation runs in {efficient_ms:.3f} milliseconds")