
Flash Attention Implementation


GPU Info

import subprocess

print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)
Wed Oct 29 00:36:31 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.195.03             Driver Version: 570.195.03     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA L40S                    On  |   00000000:4D:00.0 Off |                    0 |
| N/A   32C    P0            151W /  350W |       0MiB /  46068MiB |     86%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+
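
The same device information can also be queried programmatically. A minimal sketch using PyTorch's own device query (torch.cuda.get_device_properties) instead of nvidia-smi; device index 0 is assumed here:

import torch

# Report the active CUDA device directly from PyTorch (device 0 assumed).
if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    print(f"{props.name}: {props.total_memory // 2**20} MiB, "
          f"SM {props.major}.{props.minor}, {props.multi_processor_count} SMs")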

Flash Attention Benchmark

# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch==2.8.0",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { path = "../../../../../tools", editable = true }
# ///
import torch
from kernels_benchmark_tools import KernelTypeEnum, run_benchmark


def torch_flash(q, k, v):
    # Move to the (batch, heads, seq, dim) layout SDPA expects; .contiguous()
    # materializes copies on the way in and on the way out.
    qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
    return o.transpose(1, 2).contiguous()


run_benchmark(
    kernel_type=KernelTypeEnum.ATTENTION,
    impl_name="torch_flash_ma",
    impl_tags={"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"},
    impl_func=torch_flash,
)
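
For reference, a self-contained sketch of the same call pattern outside the harness (the shapes here are made up; the real workloads come from kernels-benchmark-tools). The transpose/.contiguous() round-trip is what later shows up as aten::copy_ / elementwise_kernel time in the traces below:

import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# Hypothetical (batch, seq, heads, dim) workload; dtype matches the benchmark.
q = torch.randn(2, 512, 16, 64, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Swap seq/heads into the (batch, heads, seq, dim) layout SDPA expects.
qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    o = F.scaled_dot_product_attention(qt, kt, vt)
o = o.transpose(1, 2).contiguous()  # restore (batch, seq, heads, dim)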
Running attention benchmark on cuda with 6 workloads.
======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L128_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       3.578ms       102.17%       3.578ms       3.578ms             1  
                                         torch_flash_ma         6.87%     353.422us        46.38%       2.386ms       2.386ms       0.000us         0.00%       3.542ms       3.542ms             1  
                     aten::scaled_dot_product_attention         0.81%      41.691us         4.31%     221.887us      73.962us       0.000us         0.00%       2.788ms     929.262us             3  
              aten::_scaled_dot_product_flash_attention         0.53%      27.420us         3.50%     180.196us      60.065us       0.000us         0.00%       2.788ms     929.262us             3  
                         aten::_flash_attention_forward         0.77%      39.803us         2.56%     131.456us      43.819us       2.788ms        79.61%       2.788ms     929.262us             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       2.788ms        79.61%       2.788ms     929.262us             3  
                                       aten::contiguous         0.28%      14.581us        33.97%       1.748ms     145.626us       0.000us         0.00%     754.272us      62.856us            12  
                                            aten::clone         0.77%      39.360us        33.69%       1.733ms     144.411us       0.000us         0.00%     754.272us      62.856us            12  
                                            aten::copy_         1.64%      84.313us        31.38%       1.614ms     134.494us     713.920us        20.39%     754.272us      62.856us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     713.920us        20.39%     713.920us      59.493us            12  
                                Activity Buffer Request        27.68%       1.424ms        27.68%       1.424ms       1.424ms      40.352us         1.15%      40.352us      40.352us             1  
                                        aten::transpose         1.22%      62.617us         1.64%      84.135us       3.506us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.42%      21.518us         0.42%      21.518us       0.897us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.49%      25.079us         1.99%     102.243us       6.816us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.77%      91.033us         1.77%      91.033us       3.793us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         2.57%     132.402us         2.57%     132.402us       8.827us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.32%      16.702us         0.32%      16.702us       5.567us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.05%       2.750us         0.05%       2.750us       0.458us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.17%       9.001us         0.17%       9.001us       3.000us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        53.62%       2.758ms        53.62%       2.758ms       2.758ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 5.144ms
Self CUDA time total: 3.502ms

======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L256_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         4.93%     257.698us        42.06%       2.199ms       2.199ms       0.000us         0.00%       3.742ms       3.742ms             1  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       3.698ms       100.30%       3.698ms       3.698ms             1  
                     aten::scaled_dot_product_attention         0.48%      25.212us         3.48%     182.067us      60.689us       0.000us         0.00%       2.929ms     976.488us             3  
              aten::_scaled_dot_product_flash_attention         0.39%      20.471us         3.00%     156.855us      52.285us       0.000us         0.00%       2.929ms     976.488us             3  
                         aten::_flash_attention_forward         0.74%      38.430us         2.18%     114.074us      38.025us       2.929ms        79.45%       2.929ms     976.488us             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       2.929ms        79.45%       2.929ms     976.488us             3  
                                       aten::contiguous         0.17%       9.122us        32.76%       1.713ms     142.713us       0.000us         0.00%     812.318us      67.693us            12  
                                            aten::clone         0.59%      31.068us        32.59%       1.703ms     141.953us       0.000us         0.00%     812.318us      67.693us            12  
                                            aten::copy_         1.50%      78.513us        30.83%       1.612ms     134.315us     757.726us        20.55%     812.318us      67.693us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     757.726us        20.55%     757.726us      63.144us            12  
                                Activity Buffer Request        27.74%       1.450ms        27.74%       1.450ms       1.450ms      54.592us         1.48%      54.592us      54.592us             1  
                                        aten::transpose         0.99%      51.637us         1.32%      68.781us       2.866us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.33%      17.144us         0.33%      17.144us       0.714us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.41%      21.274us         1.52%      79.248us       5.283us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.40%      73.206us         1.40%      73.206us       3.050us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         2.03%     106.061us         2.03%     106.061us       7.071us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.26%      13.410us         0.26%      13.410us       4.470us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.04%       1.900us         0.04%       1.900us       0.317us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.07%       3.830us         0.07%       3.830us       1.277us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        57.94%       3.028ms        57.94%       3.028ms       3.028ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 5.227ms
Self CUDA time total: 3.687ms

======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L320_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         4.92%     259.759us        41.31%       2.182ms       2.182ms       0.000us         0.00%       3.825ms       3.825ms             1  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       3.778ms       100.30%       3.778ms       3.778ms             1  
                     aten::scaled_dot_product_attention         0.46%      24.480us         3.48%     183.685us      61.228us       0.000us         0.00%       2.990ms     996.566us             3  
              aten::_scaled_dot_product_flash_attention         0.36%      18.972us         3.01%     159.205us      53.068us       0.000us         0.00%       2.990ms     996.566us             3  
                         aten::_flash_attention_forward         0.75%      39.470us         2.21%     116.583us      38.861us       2.990ms        79.38%       2.990ms     996.566us             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       2.990ms        79.38%       2.990ms     996.566us             3  
                                       aten::contiguous         0.20%      10.370us        32.06%       1.693ms     141.118us       0.000us         0.00%     835.605us      69.634us            12  
                                            aten::clone         0.56%      29.562us        31.86%       1.683ms     140.254us       0.000us         0.00%     835.605us      69.634us            12  
                                            aten::copy_         1.55%      81.613us        30.00%       1.585ms     132.057us     776.758us        20.62%     835.605us      69.634us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     776.758us        20.62%     776.758us      64.730us            12  
                                Activity Buffer Request        26.94%       1.423ms        26.94%       1.423ms       1.423ms      58.847us         1.56%      58.847us      58.847us             1  
                                        aten::transpose         0.97%      51.460us         1.30%      68.660us       2.861us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.33%      17.200us         0.33%      17.200us       0.717us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.39%      20.693us         1.67%      88.333us       5.889us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.54%      81.451us         1.54%      81.451us       3.394us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         1.97%     104.004us         1.97%     104.004us       6.934us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.28%      14.530us         0.28%      14.530us       4.843us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.04%       1.902us         0.04%       1.902us       0.317us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.07%       3.600us         0.07%       3.600us       1.200us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        58.69%       3.100ms        58.69%       3.100ms       3.100ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 5.282ms
Self CUDA time total: 3.766ms

======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L384_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         4.63%     260.119us        43.14%       2.422ms       2.422ms       0.000us         0.00%       3.911ms       3.911ms             1  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       3.865ms       100.31%       3.865ms       3.865ms             1  
                     aten::scaled_dot_product_attention         0.43%      24.361us         3.22%     180.586us      60.195us       0.000us         0.00%       3.069ms       1.023ms             3  
              aten::_scaled_dot_product_flash_attention         0.35%      19.401us         2.78%     156.225us      52.075us       0.000us         0.00%       3.069ms       1.023ms             3  
                         aten::_flash_attention_forward         0.68%      38.111us         2.03%     114.053us      38.018us       3.069ms        79.64%       3.069ms       1.023ms             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       3.069ms        79.64%       3.069ms       1.023ms             3  
                                       aten::contiguous         0.17%       9.669us        34.46%       1.935ms     161.211us       0.000us         0.00%     842.147us      70.179us            12  
                                            aten::clone         0.54%      30.453us        34.29%       1.925ms     160.405us       0.000us         0.00%     842.147us      70.179us            12  
                                            aten::copy_         1.42%      79.471us        32.63%       1.832ms     152.656us     784.675us        20.36%     842.147us      70.179us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     784.675us        20.36%     784.675us      65.390us            12  
                                Activity Buffer Request        26.20%       1.471ms        26.20%       1.471ms       1.471ms      57.472us         1.49%      57.472us      57.472us             1  
                                        aten::transpose         0.92%      51.697us         1.23%      69.261us       2.886us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.31%      17.564us         0.31%      17.564us       0.732us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.36%      20.299us         1.45%      81.452us       5.430us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.34%      75.405us         1.34%      75.405us       3.142us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         5.43%     304.654us         5.43%     304.654us      20.310us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.25%      13.960us         0.25%      13.960us       4.653us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.03%       1.839us         0.03%       1.839us       0.306us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.07%       3.750us         0.07%       3.750us       1.250us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        56.86%       3.192ms        56.86%       3.192ms       3.192ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 5.614ms
Self CUDA time total: 3.854ms

======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L448_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         5.20%     312.192us        40.27%       2.420ms       2.420ms       0.000us         0.00%       4.370ms       4.370ms             1  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       4.320ms       100.26%       4.320ms       4.320ms             1  
                     aten::scaled_dot_product_attention         0.42%      25.401us         3.13%     188.317us      62.772us       0.000us         0.00%       3.499ms       1.166ms             3  
              aten::_scaled_dot_product_flash_attention         0.34%      20.373us         2.71%     162.916us      54.305us       0.000us         0.00%       3.499ms       1.166ms             3  
                         aten::_flash_attention_forward         0.70%      41.822us         1.99%     119.463us      39.821us       3.499ms        81.21%       3.499ms       1.166ms             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       3.499ms        81.21%       3.499ms       1.166ms             3  
                                       aten::contiguous         0.17%      10.061us        31.18%       1.873ms     156.120us       0.000us         0.00%     870.813us      72.568us            12  
                                            aten::clone         0.51%      30.510us        31.01%       1.863ms     155.281us       0.000us         0.00%     870.813us      72.568us            12  
                                            aten::copy_         1.32%      79.253us        29.46%       1.770ms     147.488us     809.726us        18.79%     870.813us      72.568us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     809.726us        18.79%     809.726us      67.477us            12  
                                Activity Buffer Request        23.71%       1.425ms        23.71%       1.425ms       1.425ms      61.087us         1.42%      61.087us      61.087us             1  
                                        aten::transpose         0.85%      51.371us         1.15%      68.940us       2.873us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.29%      17.569us         0.29%      17.569us       0.732us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.34%      20.420us         1.39%      83.415us       5.561us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.27%      76.235us         1.27%      76.235us       3.176us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         4.81%     288.717us         4.81%     288.717us      19.248us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.26%      15.360us         0.26%      15.360us       5.120us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.03%       1.980us         0.03%       1.980us       0.330us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.06%       3.780us         0.06%       3.780us       1.260us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        59.73%       3.589ms        59.73%       3.589ms       3.589ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 6.009ms
Self CUDA time total: 4.309ms

======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L512_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         4.62%     283.749us        39.30%       2.416ms       2.416ms       0.000us         0.00%       4.488ms       4.488ms             1  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       4.437ms       100.26%       4.437ms       4.437ms             1  
                     aten::scaled_dot_product_attention         0.41%      25.050us         2.99%     183.606us      61.202us       0.000us         0.00%       3.606ms       1.202ms             3  
              aten::_scaled_dot_product_flash_attention         0.32%      19.512us         2.58%     158.556us      52.852us       0.000us         0.00%       3.606ms       1.202ms             3  
                         aten::_flash_attention_forward         0.64%      39.583us         1.89%     116.223us      38.741us       3.606ms        81.47%       3.606ms       1.202ms             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       3.606ms        81.47%       3.606ms       1.202ms             3  
                                       aten::contiguous         0.16%       9.930us        30.93%       1.901ms     158.420us       0.000us         0.00%     882.206us      73.517us            12  
                                            aten::clone         0.49%      30.220us        30.76%       1.891ms     157.592us       0.000us         0.00%     882.206us      73.517us            12  
                                            aten::copy_         1.34%      82.326us        29.23%       1.797ms     149.726us     820.351us        18.53%     882.206us      73.517us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     820.351us        18.53%     820.351us      68.363us            12  
                                Activity Buffer Request        23.42%       1.439ms        23.42%       1.439ms       1.439ms      61.855us         1.40%      61.855us      61.855us             1  
                                        aten::transpose         0.85%      52.248us         1.14%      70.082us       2.920us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.29%      17.834us         0.29%      17.834us       0.743us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.33%      20.531us         1.36%      83.782us       5.585us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.26%      77.251us         1.26%      77.251us       3.219us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         4.84%     297.592us         4.84%     297.592us      19.839us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.24%      14.660us         0.24%      14.660us       4.887us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.03%       1.929us         0.03%       1.929us       0.321us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.06%       3.839us         0.06%       3.839us       1.280us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        60.70%       3.731ms        60.70%       3.731ms       3.731ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 6.147ms
Self CUDA time total: 4.426ms

impl                     wl                  p50(ms)  ok
torch_flash_ma           cuda_attn_L128_bfloat16     1.21  True
torch_flash_ma           cuda_attn_L256_bfloat16     1.27  True
torch_flash_ma           cuda_attn_L320_bfloat16     1.29  True
torch_flash_ma           cuda_attn_L384_bfloat16     1.32  True
torch_flash_ma           cuda_attn_L448_bfloat16     1.47  True
torch_flash_ma           cuda_attn_L512_bfloat16     1.49  True
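
The per-operator tables above match the format of torch.profiler's key_averages().table() output. A minimal sketch of how such a trace can be reproduced for one workload (the actual profiling happens inside kernels-benchmark-tools, so treat this as an illustration only):

import torch
from torch.profiler import profile, ProfilerActivity

# q, k, v and torch_flash as in the sketches above.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    out = torch_flash(q, k, v)
    torch.cuda.synchronize()
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))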

Artifacts:

attention.jsonl
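
attention.jsonl presumably holds one benchmark record per line. A minimal sketch for inspecting it (the field names are whatever kernels-benchmark-tools wrote, so no schema is assumed here):

import json

with open("attention.jsonl") as f:
    for line in f:
        record = json.loads(line)  # one JSON record per line
        print(record)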