"""Benchmark the backends of ``torch.nn.functional.scaled_dot_product_attention``.

Times the default SDPA dispatch and then each explicitly selected backend
(math, flash attention, memory efficient) on random query/key/value tensors,
printing the mean runtime of one call in milliseconds. Run as a script.
"""
import torch
import torch.nn.functional as F
import torch.utils.benchmark as benchmark
from torch.backends.cuda import sdp_kernel, SDPBackend

device = "cuda" if torch.cuda.is_available() else "cpu"


def benchmark_torch_function_in_milliseconds(f, *args, **kwargs):
    """Return the mean runtime of ``f(*args, **kwargs)`` in milliseconds.

    Uses ``torch.utils.benchmark.Timer.blocked_autorange``, which repeats the
    call in blocks until its minimum total run time is reached and reports
    per-call statistics in seconds.
    """
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e3  # Convert to milliseconds


batch_size = 32
max_sequence_len = 1024
num_heads = 32
embed_dimension = 32
# Half precision is the intended benchmark dtype on GPU. On the CPU fallback,
# fp16 SDPA may be unsupported or extremely slow depending on the PyTorch
# build, so use fp32 there instead.
dtype = torch.float16 if device == "cuda" else torch.float32

# Keyword flags for ``sdp_kernel`` that leave exactly one SDPA backend enabled.
# NOTE(review): ``torch.backends.cuda.sdp_kernel`` is deprecated in newer
# PyTorch releases in favour of ``torch.nn.attention.sdpa_kernel``.
backend_map = {
    SDPBackend.MATH: {"enable_math": True, "enable_flash": False, "enable_mem_efficient": False},
    SDPBackend.FLASH_ATTENTION: {"enable_math": False, "enable_flash": True, "enable_mem_efficient": False},
    SDPBackend.EFFICIENT_ATTENTION: {"enable_math": False, "enable_flash": False, "enable_mem_efficient": True},
}

# Backends that can be rejected for a given device/dtype/shape: the backend,
# the label for its timing line, and the name for its "not supported" line.
_OPTIONAL_BACKENDS = (
    (SDPBackend.FLASH_ATTENTION, "flash attention", "FlashAttention"),
    (SDPBackend.EFFICIENT_ATTENTION, "memory efficient", "EfficientAttention"),
)


def main():
    """Allocate random q/k/v tensors and print one timing line per backend."""
    shape = (batch_size, num_heads, max_sequence_len, embed_dimension)
    query = torch.rand(*shape, device=device, dtype=dtype)
    key = torch.rand(*shape, device=device, dtype=dtype)
    value = torch.rand(*shape, device=device, dtype=dtype)

    def time_sdpa():
        # Time one SDPA call under whatever backend flags are currently active.
        return benchmark_torch_function_in_milliseconds(
            F.scaled_dot_product_attention, query, key, value
        )

    print(f"The default implementation runs in {time_sdpa():.3f} milliseconds")

    # The math backend is the reference implementation and is expected to be
    # available everywhere, so it is not guarded.
    with sdp_kernel(**backend_map[SDPBackend.MATH]):
        print(f"The math implementation runs in {time_sdpa():.3f} milliseconds")

    for backend, label, name in _OPTIONAL_BACKENDS:
        with sdp_kernel(**backend_map[backend]):
            try:
                elapsed = time_sdpa()
            except RuntimeError:
                # Raised when no enabled backend can handle these inputs;
                # PyTorch emits warnings explaining why.
                print(f"{name} is not supported. See warnings for reasons.")
            else:
                print(f"The {label} implementation runs in {elapsed:.3f} milliseconds")


if __name__ == "__main__":
    main()