Remove sources
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
- README.md +6 -5
- build.toml +0 -593
- flake.lock +0 -168
- flake.nix +0 -51
- flash-attn/block.h +0 -94
- flash-attn/copy_sm90_bulk_reduce.hpp +0 -49
- flash-attn/cuda_check.h +0 -19
- flash-attn/epilogue_bwd.hpp +0 -523
- flash-attn/epilogue_fwd.hpp +0 -484
- flash-attn/flash.h +0 -220
- flash-attn/flash_api.cpp +0 -1623
- flash-attn/flash_bwd_kernel_sm80.h +0 -173
- flash-attn/flash_bwd_kernel_sm90.h +0 -282
- flash-attn/flash_bwd_launch_template.h +0 -377
- flash-attn/flash_bwd_postprocess_kernel.h +0 -256
- flash-attn/flash_bwd_preprocess_kernel.h +0 -252
- flash-attn/flash_fwd_combine.cu +0 -13
- flash-attn/flash_fwd_combine_kernel.h +0 -702
- flash-attn/flash_fwd_combine_launch_template.h +0 -88
- flash-attn/flash_fwd_kernel_sm80.h +0 -215
- flash-attn/flash_fwd_kernel_sm90.h +0 -468
- flash-attn/flash_fwd_launch_template.h +0 -231
- flash-attn/flash_prepare_scheduler.cu +0 -124
- flash-attn/heuristics.h +0 -65
- flash-attn/instantiations/flash_bwd_hdim128_bf16_sm80.cu +0 -18
- flash-attn/instantiations/flash_bwd_hdim128_bf16_sm90.cu +0 -12
- flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm80.cu +0 -18
- flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm90.cu +0 -12
- flash-attn/instantiations/flash_bwd_hdim128_bf16_softcapall_sm90.cu +0 -6
- flash-attn/instantiations/flash_bwd_hdim128_fp16_sm80.cu +0 -18
- flash-attn/instantiations/flash_bwd_hdim128_fp16_sm90.cu +0 -12
- flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm80.cu +0 -18
- flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm90.cu +0 -12
- flash-attn/instantiations/flash_bwd_hdim128_fp16_softcapall_sm90.cu +0 -6
- flash-attn/instantiations/flash_bwd_hdim192_bf16_sm80.cu +0 -18
- flash-attn/instantiations/flash_bwd_hdim192_bf16_sm90.cu +0 -12
- flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm80.cu +0 -18
- flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm90.cu +0 -12
- flash-attn/instantiations/flash_bwd_hdim192_bf16_softcapall_sm90.cu +0 -6
- flash-attn/instantiations/flash_bwd_hdim192_fp16_sm80.cu +0 -18
- flash-attn/instantiations/flash_bwd_hdim192_fp16_sm90.cu +0 -12
- flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm80.cu +0 -18
- flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm90.cu +0 -12
- flash-attn/instantiations/flash_bwd_hdim192_fp16_softcapall_sm90.cu +0 -6
- flash-attn/instantiations/flash_bwd_hdim256_bf16_sm80.cu +0 -18
- flash-attn/instantiations/flash_bwd_hdim256_bf16_sm90.cu +0 -12
- flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm80.cu +0 -18
- flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm90.cu +0 -12
- flash-attn/instantiations/flash_bwd_hdim256_bf16_softcapall_sm90.cu +0 -6
- flash-attn/instantiations/flash_bwd_hdim256_fp16_sm80.cu +0 -18
    	
README.md CHANGED

@@ -1,13 +1,14 @@
 ---
 license: apache-2.0
 tags:
-- kernel
+  - kernel
 ---
-
+
 # vllm-flash-attn3
 
 This is an implementation of Flash Attention 3 CUDA kernels with support for attention sinks. The attention sinks implementation was contributed to Flash Attention by the [vLLM team](https://huggingface.co/vllm-project). The [transformers team](https://huggingface.co/transformers-community) packaged the implementation and pre-built it for use with the [kernels library](https://github.com/huggingface/kernels).
 
+Kernel source: https://github.com/huggingface/kernels-community/tree/main/vllm-flash-attn3
 
 ## Quickstart
 
@@ -43,7 +44,7 @@ torch.cuda.manual_seed(42)
 # Parameters
 batch_size = 2
 seqlen_q = 128  # Query sequence length
-seqlen_k = 256  # Key sequence length
+seqlen_k = 256  # Key sequence length
 nheads = 8      # Number of attention heads
 d = 64          # Head dimension
 
@@ -65,7 +66,6 @@ print(f"\nAttention computation successful!")
 print(f"Output tensor stats - Mean: {output.mean().item():.4f}, Std: {output.std().item():.4f}")
 ```
 
-
 ## How to Use
 
 When loading your model with transformers, provide this repository id as the source of the attention implementation:
@@ -91,4 +91,5 @@ This will automatically resolve and download the appropriate code for your archi
 
 - [Tri Dao](https://huggingface.co/tridao) and team for Flash Attention and [Flash Attention 3](https://tridao.me/blog/2024/flash3/).
 - The [vLLM team](https://huggingface.co/vllm-project) for their implementation and their contribution of attention sinks.
-- The [transformers team](https://huggingface.co/transformers-community) for packaging, testing, building and making it available for use with the [kernels library](https://github.com/huggingface/kernels).
+- The [transformers team](https://huggingface.co/transformers-community) for packaging, testing, building and making it available for use with the [kernels library](https://github.com/huggingface/kernels).
+
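The "How to Use" hunk above is the whole integration story: with a transformers release that includes the kernels integration, the repository id is passed directly as the attention implementation. A minimal sketch of that call, assuming such a release is installed; the model id is a placeholder, not something named in this diff:

```python
import torch
from transformers import AutoModelForCausalLM

# Route the model's attention through the pre-built kernel from the Hub.
# Passing a kernels-community repo id to `attn_implementation` assumes a
# transformers version with the kernels-library integration described in
# the README; the library resolves and downloads the binary matching the
# local torch/CUDA build.
model = AutoModelForCausalLM.from_pretrained(
    "<your-model-id>",  # placeholder
    torch_dtype=torch.bfloat16,
    attn_implementation="kernels-community/vllm-flash-attn3",
)
```

The build variants that pre-built binary ships in were declared by the build.toml removed below.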
build.toml DELETED

@@ -1,593 +0,0 @@
-[general]
-name = "vllm_flash_attn3"
-universal = false
-cuda-minver = "12.4"
-cuda-maxver = "12.4"
-
-[torch]
-src = [
-  "torch-ext/pytorch_shim.h",
-  "torch-ext/torch_binding.cpp",
-  "torch-ext/torch_binding.h",
-]
-
-[kernel.flash_attn]
-backend = "cuda"
-cuda-capabilities = ["8.0", "9.0a"]
-cuda-flags = [
-  "-O3",
-  "-std=c++17",
-  "--ftemplate-backtrace-limit=0",              # To debug template code
-  "--use_fast_math",
-  "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
-  "-DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1",
-  "-DCUTLASS_ENABLE_GDC_FOR_SM90",
-  "--expt-relaxed-constexpr",
-  "--expt-extended-lambda",
-  "--use_fast_math",
-  "-DNDEBUG",
-]
-cxx-flags = ["-DFLASHATTENTION_DISABLE_PYBIND"]
-src = [
-  "flash-attn/cuda_check.h",
-  "flash-attn/flash_api.cpp",
-  "flash-attn/flash_fwd_combine.cu",
-  "flash-attn/flash_fwd_combine_kernel.h",
-  "flash-attn/flash_fwd_combine_launch_template.h",
-  "flash-attn/flash.h",
-  "flash-attn/flash_prepare_scheduler.cu",
-  "flash-attn/heuristics.h",
-  "flash-attn/seqlen.h",
-  "flash-attn/static_switch.h",
-  "flash-attn/tile_size.h",
-  "flash-attn/utils.h",
-]
-depends = ["torch", "cutlass_3_9"]
-
-[kernel.flash_attn_sm80]
-backend = "cuda"
-cuda-capabilities = ["8.0", "9.0a"]
-cuda-flags = [
-  "-O3",
-  "-std=c++17",
-  "--ftemplate-backtrace-limit=0",              # To debug template code
-  "--use_fast_math",
-  "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
-  "-DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1",
-  "-DCUTLASS_ENABLE_GDC_FOR_SM90",
-  "--expt-relaxed-constexpr",
-  "--expt-extended-lambda",
-  "--use_fast_math",
-  "-DNDEBUG",
-]
-src = [
-  "flash-attn/block.h",
-  "flash-attn/copy_sm90_bulk_reduce.hpp",
-  "flash-attn/epilogue_bwd.hpp",
-  "flash-attn/epilogue_fwd.hpp",
-  "flash-attn/flash.h",
-  "flash-attn/flash_bwd_kernel_sm80.h",
-  "flash-attn/flash_bwd_kernel_sm90.h",
-  "flash-attn/flash_bwd_launch_template.h",
-  "flash-attn/flash_bwd_postprocess_kernel.h",
-  "flash-attn/flash_bwd_preprocess_kernel.h",
-  "flash-attn/flash_fwd_launch_template.h",
-  "flash-attn/flash_fwd_kernel_sm80.h",
-  "flash-attn/flash_fwd_kernel_sm90.h",
-  "flash-attn/heuristics.h",
-  "flash-attn/mainloop_bwd_sm80.hpp",
-  "flash-attn/mainloop_fwd_sm80.hpp",
-  "flash-attn/mainloop_bwd_sm90_tma_gmma_ws.hpp",
-  "flash-attn/mainloop_fwd_sm90_tma_gmma_ws.hpp",
-  "flash-attn/mask.h",
-  "flash-attn/named_barrier.hpp",
-  "flash-attn/pack_gqa.h",
-  "flash-attn/paged_kv.h",
-  "flash-attn/rotary.h",
-  "flash-attn/sm90_pipeline_no_cluster.hpp",
-  "flash-attn/softmax.h",
-  "flash-attn/tile_size.h",
-  "flash-attn/tile_scheduler.hpp",
-
-  "flash-attn/instantiations/flash_bwd_hdim128_bf16_sm80.cu",
-  "flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_bwd_hdim128_fp16_sm80.cu",
-  "flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_bwd_hdim192_bf16_sm80.cu",
-  "flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_bwd_hdim192_fp16_sm80.cu",
-  "flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_bwd_hdim256_bf16_sm80.cu",
-  "flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_bwd_hdim256_fp16_sm80.cu",
-  "flash-attn/instantiations/flash_bwd_hdim256_fp16_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_bwd_hdim64_bf16_sm80.cu",
-  "flash-attn/instantiations/flash_bwd_hdim64_bf16_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_bwd_hdim64_fp16_sm80.cu",
-  "flash-attn/instantiations/flash_bwd_hdim64_fp16_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_bwd_hdim96_bf16_sm80.cu",
-  "flash-attn/instantiations/flash_bwd_hdim96_bf16_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_bwd_hdim96_fp16_sm80.cu",
-  "flash-attn/instantiations/flash_bwd_hdim96_fp16_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_bf16_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_bf16_split_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_fp16_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_fp16_split_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_bf16_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_bf16_split_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_fp16_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_fp16_split_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_bf16_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_bf16_split_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_fp16_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_fp16_split_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_bf16_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_bf16_split_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_fp16_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_fp16_split_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_bf16_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_bf16_split_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_fp16_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_fp16_split_sm80.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm80.cu"
-]
-include = ["flash-attn"]
-depends = ["torch", "cutlass_3_9"]
-
-[kernel.flash_attn_sm90]
-backend = "cuda"
-cuda-capabilities = ["8.0", "9.0a"]
-cuda-flags = [
-  "-O3",
-  "-std=c++17",
-  "--ftemplate-backtrace-limit=0",              # To debug template code
-  "--use_fast_math",
-  "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
-  "-DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1",
-  "-DCUTLASS_ENABLE_GDC_FOR_SM90",
-  "--expt-relaxed-constexpr",
-  "--expt-extended-lambda",
-  "--use_fast_math",
-  "-DNDEBUG",
-]
-src = [
-  "flash-attn/block.h",
-  "flash-attn/copy_sm90_bulk_reduce.hpp",
-  "flash-attn/epilogue_bwd.hpp",
-  "flash-attn/epilogue_fwd.hpp",
-  "flash-attn/flash.h",
-  "flash-attn/flash_bwd_kernel_sm80.h",
-  "flash-attn/flash_bwd_kernel_sm90.h",
-  "flash-attn/flash_bwd_launch_template.h",
-  "flash-attn/flash_bwd_postprocess_kernel.h",
-  "flash-attn/flash_bwd_preprocess_kernel.h",
-  "flash-attn/flash_fwd_launch_template.h",
-  "flash-attn/flash_fwd_kernel_sm80.h",
-  "flash-attn/flash_fwd_kernel_sm90.h",
-  "flash-attn/heuristics.h",
-  "flash-attn/mainloop_bwd_sm80.hpp",
-  "flash-attn/mainloop_fwd_sm80.hpp",
-  "flash-attn/mainloop_bwd_sm90_tma_gmma_ws.hpp",
-  "flash-attn/mainloop_fwd_sm90_tma_gmma_ws.hpp",
-  "flash-attn/mask.h",
-  "flash-attn/named_barrier.hpp",
-  "flash-attn/pack_gqa.h",
-  "flash-attn/paged_kv.h",
-  "flash-attn/rotary.h",
-  "flash-attn/sm90_pipeline_no_cluster.hpp",
-  "flash-attn/softmax.h",
-  "flash-attn/tile_size.h",
-  "flash-attn/tile_scheduler.hpp",
-
-  "flash-attn/instantiations/flash_bwd_hdim128_bf16_sm90.cu",
-  "flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_bwd_hdim128_fp16_sm90.cu",
-  "flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_bwd_hdim192_bf16_sm90.cu",
-  "flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_bwd_hdim192_fp16_sm90.cu",
-  "flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_bwd_hdim256_bf16_sm90.cu",
-  "flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_bwd_hdim256_fp16_sm90.cu",
-  "flash-attn/instantiations/flash_bwd_hdim256_fp16_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_bwd_hdim64_bf16_sm90.cu",
-  "flash-attn/instantiations/flash_bwd_hdim64_bf16_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_bwd_hdim64_fp16_sm90.cu",
-  "flash-attn/instantiations/flash_bwd_hdim64_fp16_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_bwd_hdim96_bf16_sm90.cu",
-  "flash-attn/instantiations/flash_bwd_hdim96_bf16_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_bwd_hdim96_fp16_sm90.cu",
-  "flash-attn/instantiations/flash_bwd_hdim96_fp16_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_bf16_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_bf16_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_bf16_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_e4m3_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_e4m3_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_e4m3_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_e4m3_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_e4m3_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_fp16_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_fp16_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_fp16_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_bf16_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_bf16_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_bf16_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_e4m3_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_e4m3_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_e4m3_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_e4m3_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_e4m3_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_fp16_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_fp16_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_fp16_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_bf16_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_bf16_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_bf16_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_e4m3_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_e4m3_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_e4m3_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_e4m3_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_e4m3_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_fp16_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_fp16_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_fp16_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_softcap_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_softcap_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_bf16_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_bf16_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_bf16_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_e4m3_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_e4m3_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_e4m3_softcap_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_e4m3_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_e4m3_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_e4m3_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_fp16_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_fp16_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_fp16_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_bf16_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_bf16_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_bf16_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_e4m3_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_e4m3_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_e4m3_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_e4m3_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_e4m3_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_fp16_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_fp16_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_fp16_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdimall_bf16_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdimall_bf16_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdimall_bf16_softcap_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdimall_bf16_split_sm90.cu",
-  "flash-attn/instantiations/flash_fwd_hdimall_bf16_split_softcap_sm90.cu",
         | 
| 491 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimall_e4m3_packgqa_sm90.cu",
         | 
| 492 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_sm90.cu",
         | 
| 493 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_softcap_sm90.cu",
         | 
| 494 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_split_sm90.cu",
         | 
| 495 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_split_softcap_sm90.cu",
         | 
| 496 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimall_e4m3_sm90.cu",
         | 
| 497 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimall_e4m3_softcap_packgqa_sm90.cu",
         | 
| 498 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimall_e4m3_softcap_sm90.cu",
         | 
| 499 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimall_e4m3_split_sm90.cu",
         | 
| 500 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimall_e4m3_split_softcap_sm90.cu",
         | 
| 501 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimall_fp16_packgqa_sm90.cu",
         | 
| 502 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_sm90.cu",
         | 
| 503 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_softcap_sm90.cu",
         | 
| 504 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_split_sm90.cu",
         | 
| 505 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu",
         | 
| 506 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimall_fp16_sm90.cu",
         | 
| 507 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimall_fp16_softcap_packgqa_sm90.cu",
         | 
| 508 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimall_fp16_softcap_sm90.cu",
         | 
| 509 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimall_fp16_split_sm90.cu",
         | 
| 510 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimall_fp16_split_softcap_sm90.cu",
         | 
| 511 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_packgqa_sm90.cu",
         | 
| 512 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_sm90.cu",
         | 
| 513 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu",
         | 
| 514 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_split_sm90.cu",
         | 
| 515 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu",
         | 
| 516 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_sm90.cu",
         | 
| 517 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu",
         | 
| 518 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_softcap_sm90.cu",
         | 
| 519 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_split_sm90.cu",
         | 
| 520 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu",
         | 
| 521 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_packgqa_sm90.cu",
         | 
| 522 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_sm90.cu",
         | 
| 523 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_softcap_sm90.cu",
         | 
| 524 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_sm90.cu",
         | 
| 525 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_softcap_sm90.cu",
         | 
| 526 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_sm90.cu",
         | 
| 527 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_softcap_packgqa_sm90.cu",
         | 
| 528 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_softcap_sm90.cu",
         | 
| 529 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_split_sm90.cu",
         | 
| 530 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_split_softcap_sm90.cu",
         | 
| 531 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_packgqa_sm90.cu",
         | 
| 532 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu",
         | 
| 533 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu",
         | 
| 534 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_split_sm90.cu",
         | 
| 535 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu",
         | 
| 536 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_sm90.cu",
         | 
| 537 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu",
         | 
| 538 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_softcap_sm90.cu",
         | 
| 539 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_split_sm90.cu",
         | 
| 540 | 
            -
              "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu",
         | 
| 541 | 
            -
            ]
         | 
| 542 | 
            -
            include = ["flash-attn"]
         | 
| 543 | 
            -
            depends = ["torch", "cutlass_3_9"]
         | 
| 544 | 
            -
             | 
| 545 | 
            -
            # [kernel.flash_attn_sm100]
         | 
| 546 | 
            -
            # backend = "cuda"
         | 
| 547 | 
            -
            # cuda-capabilities = ["8.0", "9.0a", "10.0"]
         | 
| 548 | 
            -
            # cuda-flags = [
         | 
| 549 | 
            -
            #   "-O3",
         | 
| 550 | 
            -
            #   "-std=c++17",
         | 
| 551 | 
            -
            #   "--ftemplate-backtrace-limit=0",              # To debug template code
         | 
| 552 | 
            -
            #   "--use_fast_math",
         | 
| 553 | 
            -
            #   "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
         | 
| 554 | 
            -
            #   "-DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1",
         | 
| 555 | 
            -
            #   "-DCUTLASS_ENABLE_GDC_FOR_SM90",
         | 
| 556 | 
            -
            #   "--expt-relaxed-constexpr",
         | 
| 557 | 
            -
            #   "--expt-extended-lambda",
         | 
| 558 | 
            -
            #   "--use_fast_math",
         | 
| 559 | 
            -
            #   "-DNDEBUG",
         | 
| 560 | 
            -
            # ]
         | 
| 561 | 
            -
            # src = [
         | 
| 562 | 
            -
            #   "flash-attn/block.h",
         | 
| 563 | 
            -
            #   "flash-attn/copy_sm90_bulk_reduce.hpp",
         | 
| 564 | 
            -
            #   "flash-attn/epilogue_bwd.hpp",
         | 
| 565 | 
            -
            #   "flash-attn/epilogue_fwd.hpp",
         | 
| 566 | 
            -
            #   "flash-attn/flash.h",
         | 
| 567 | 
            -
            #   "flash-attn/flash_bwd_kernel_sm80.h",
         | 
| 568 | 
            -
            #   "flash-attn/flash_bwd_kernel_sm90.h",
         | 
| 569 | 
            -
            #   "flash-attn/flash_bwd_launch_template.h",
         | 
| 570 | 
            -
            #   "flash-attn/flash_bwd_postprocess_kernel.h",
         | 
| 571 | 
            -
            #   "flash-attn/flash_bwd_preprocess_kernel.h",
         | 
| 572 | 
            -
            #   "flash-attn/flash_fwd_launch_template.h",
         | 
| 573 | 
            -
            #   "flash-attn/flash_fwd_kernel_sm80.h",
         | 
| 574 | 
            -
            #   "flash-attn/flash_fwd_kernel_sm90.h",
         | 
| 575 | 
            -
            #   "flash-attn/heuristics.h",
         | 
| 576 | 
            -
            #   "flash-attn/mainloop_bwd_sm80.hpp",
         | 
| 577 | 
            -
            #   "flash-attn/mainloop_fwd_sm80.hpp",
         | 
| 578 | 
            -
            #   "flash-attn/mainloop_bwd_sm90_tma_gmma_ws.hpp",
         | 
| 579 | 
            -
            #   "flash-attn/mainloop_fwd_sm90_tma_gmma_ws.hpp",
         | 
| 580 | 
            -
            #   "flash-attn/mask.h",
         | 
| 581 | 
            -
            #   "flash-attn/named_barrier.hpp",
         | 
| 582 | 
            -
            #   "flash-attn/pack_gqa.h",
         | 
| 583 | 
            -
            #   "flash-attn/paged_kv.h",
         | 
| 584 | 
            -
            #   "flash-attn/rotary.h",
         | 
| 585 | 
            -
            #   "flash-attn/sm90_pipeline_no_cluster.hpp",
         | 
| 586 | 
            -
            #   "flash-attn/softmax.h",
         | 
| 587 | 
            -
            #   "flash-attn/tile_size.h",
         | 
| 588 | 
            -
            #   "flash-attn/tile_scheduler.hpp",
         | 
| 589 | 
            -
            #
         | 
| 590 | 
            -
            #   "flash-attn/instantiations/flash_fwd_hdim128_bf16_sm100.cu",
         | 
| 591 | 
            -
            # ]
         | 
| 592 | 
            -
            # include = ["flash-attn"]
         | 
| 593 | 
            -
            # depends = ["torch", "cutlass_3_9"]
         | 
        flake.lock
    DELETED
    
@@ -1,168 +0,0 @@
-{
-  "nodes": {
-    "flake-compat": {
-      "locked": {
-        "lastModified": 1747046372,
-        "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
-        "owner": "edolstra",
-        "repo": "flake-compat",
-        "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
-        "type": "github"
-      },
-      "original": {
-        "owner": "edolstra",
-        "repo": "flake-compat",
-        "type": "github"
-      }
-    },
-    "flake-compat_2": {
-      "locked": {
-        "lastModified": 1733328505,
-        "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
-        "owner": "edolstra",
-        "repo": "flake-compat",
-        "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
-        "type": "github"
-      },
-      "original": {
-        "owner": "edolstra",
-        "repo": "flake-compat",
-        "type": "github"
-      }
-    },
-    "flake-utils": {
-      "inputs": {
-        "systems": "systems"
-      },
-      "locked": {
-        "lastModified": 1731533236,
-        "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
-        "owner": "numtide",
-        "repo": "flake-utils",
-        "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
-        "type": "github"
-      },
-      "original": {
-        "owner": "numtide",
-        "repo": "flake-utils",
-        "type": "github"
-      }
-    },
-    "flake-utils_2": {
-      "inputs": {
-        "systems": "systems_2"
-      },
-      "locked": {
-        "lastModified": 1731533236,
-        "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
-        "owner": "numtide",
-        "repo": "flake-utils",
-        "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
-        "type": "github"
-      },
-      "original": {
-        "owner": "numtide",
-        "repo": "flake-utils",
-        "type": "github"
-      }
-    },
-    "hf-nix": {
-      "inputs": {
-        "flake-compat": "flake-compat_2",
-        "flake-utils": "flake-utils_2",
-        "nixpkgs": "nixpkgs"
-      },
-      "locked": {
-        "lastModified": 1750234878,
-        "narHash": "sha256-q9DRC9zdpzUf88qqg1qbhP1qgJbE2cMtn8oUmosuyT8=",
-        "owner": "huggingface",
-        "repo": "hf-nix",
-        "rev": "c7132f90763d756da3e77da62e01be0a4546dc57",
-        "type": "github"
-      },
-      "original": {
-        "owner": "huggingface",
-        "repo": "hf-nix",
-        "type": "github"
-      }
-    },
-    "kernel-builder": {
-      "inputs": {
-        "flake-compat": "flake-compat",
-        "flake-utils": "flake-utils",
-        "hf-nix": "hf-nix",
-        "nixpkgs": [
-          "kernel-builder",
-          "hf-nix",
-          "nixpkgs"
-        ]
-      },
-      "locked": {
-        "lastModified": 1751014803,
-        "narHash": "sha256-9Xfq2k3uPfB602NwQF+zAY2GQZiKUN1G7Q6XiDCUR8Y=",
-        "owner": "huggingface",
-        "repo": "kernel-builder",
-        "rev": "bbc4e712ff2046e217818e97de2201e2b996756e",
-        "type": "github"
-      },
-      "original": {
-        "owner": "huggingface",
-        "repo": "kernel-builder",
-        "type": "github"
-      }
-    },
-    "nixpkgs": {
-      "locked": {
-        "lastModified": 1747820358,
-        "narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
-        "owner": "danieldk",
-        "repo": "nixpkgs",
-        "rev": "d3c1681180717528068082103bf323147de6ab0b",
-        "type": "github"
-      },
-      "original": {
-        "owner": "danieldk",
-        "ref": "cudatoolkit-12.9-kernel-builder",
-        "repo": "nixpkgs",
-        "type": "github"
-      }
-    },
-    "root": {
-      "inputs": {
-        "kernel-builder": "kernel-builder"
-      }
-    },
-    "systems": {
-      "locked": {
-        "lastModified": 1681028828,
-        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
-        "owner": "nix-systems",
-        "repo": "default",
-        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
-        "type": "github"
-      },
-      "original": {
-        "owner": "nix-systems",
-        "repo": "default",
-        "type": "github"
-      }
-    },
-    "systems_2": {
-      "locked": {
-        "lastModified": 1681028828,
-        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
-        "owner": "nix-systems",
-        "repo": "default",
-        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
-        "type": "github"
-      },
-      "original": {
-        "owner": "nix-systems",
-        "repo": "default",
-        "type": "github"
-      }
-    }
-  },
-  "root": "root",
-  "version": 7
-}
        flake.nix
    DELETED
    
@@ -1,51 +0,0 @@
-{
-  description = "Flake for Hopper Flash Attention kernel";
-
-  inputs = {
-    kernel-builder.url = "github:huggingface/kernel-builder";
-  };
-
-  outputs =
-    {
-      self,
-      kernel-builder,
-    }:
-    kernel-builder.lib.genFlakeOutputs {
-      path = ./.;
-      rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
-      # Building with CUDA later than 12.4 fails with:
-      #
-      # error: 'ptxas' died due to signal 11 (Invalid memory reference)
-      #
-      # So, build for 12.4 only and copy to all the other build variants
-      # by hand (which works fine thanks to backward compat).
-      #
-      # Still need to check if upstream FA3 has the same issue.
-      torchVersions = [
-        {
-          torchVersion = "2.6";
-          cudaVersion = "12.4";
-          cxx11Abi = false;
-          systems = [ "x86_64-linux" ];
-          upstreamVariant = true;
-        }
-        {
-          torchVersion = "2.6";
-          cudaVersion = "12.4";
-          cxx11Abi = true;
-          systems = [ "x86_64-linux" ];
-          upstreamVariant = true;
-        }
-        {
-          torchVersion = "2.7";
-          cudaVersion = "12.4";
-          cxx11Abi = true;
-          systems = [
-            "x86_64-linux"
-            "aarch64-linux"
-          ];
-          upstreamVariant = true;
-        }
-      ];
-    };
-}
        flash-attn/block.h
    DELETED
    
@@ -1,94 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
- ******************************************************************************/
-
-#pragma once
-
-namespace flash {
-
-template <class SeqlenInfo_t, int kBlockM, int kBlockN, bool Is_causal, bool Is_local, bool PackGQA=false, bool Split=false>
-struct BlockMN {
-
-    static
-    CUTLASS_DEVICE
-    cute::tuple<int, int> get_n_block_min_max(
-            SeqlenInfo_t const& seqlen_info,
-            int const m_block, int const bidb, int const split_idx, int const num_splits,
-            int const window_size_left, int const window_size_right,
-            cutlass::FastDivmod const& qhead_per_khead_divmod) {
-
-        int const seqlen_k = seqlen_info.seqlen_k;
-        int const seqlen_q = seqlen_info.seqlen_q;
-        int n_block_max = cute::ceil_div(seqlen_k, kBlockN);
-        if constexpr (Is_causal || Is_local) {
-            int m_idx_max = (m_block + 1) * kBlockM;
-            // TODO: check off-by-1 error
-            if (PackGQA) { m_idx_max = qhead_per_khead_divmod.divide(m_idx_max - 1) + 1; }
-            n_block_max = std::min(n_block_max,
-                                   cute::ceil_div(m_idx_max + seqlen_k - seqlen_q + window_size_right, kBlockN));
-        }
-        int n_block_min = 0;
-        if constexpr (Is_local) {
-            int m_idx_min = m_block * kBlockM;
-            if (PackGQA) { m_idx_min = qhead_per_khead_divmod.divide(m_idx_min); }
-            n_block_min = std::max(int(0), (m_idx_min + seqlen_k - seqlen_q - window_size_left) / kBlockN);
-        }
-        // if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); }
-        if constexpr (Split) {
-            uint32_t num_splits_dynamic_u = reinterpret_cast<uint32_t const&>(split_idx) >> 16; // first 16 bits are for num_splits
-            int num_splits_dynamic = reinterpret_cast<int&>(num_splits_dynamic_u);
-            int split_idx_actual = split_idx & 0x0000FFFF;
-            int num_splits_actual = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits;
-            int num_n_blocks_per_split = n_block_max <= n_block_min ? 0 : cute::ceil_div(n_block_max - n_block_min, num_splits_actual);
-            n_block_min = n_block_min + split_idx_actual * num_n_blocks_per_split;
-            n_block_max = std::min(n_block_min + num_n_blocks_per_split, n_block_max);
-            // if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, num_splits_dynamic = %d, num_splits_actual = %d, num_n_blocks_per_split = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, num_splits_dynamic, num_splits_actual, num_n_blocks_per_split, n_block_min, n_block_max); }
-        }
-        // if (threadIdx.x == 128) { printf("After split, inside, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); }
-        return {n_block_min, n_block_max};
-    }
-
-    static
-    CUTLASS_DEVICE
-    cute::tuple<int, int> get_n_block_k_new_min_max(
-            SeqlenInfo_t const& seqlen_info,
-            int const m_block, int const bidb, int const split_idx, int const num_splits,
-            int const window_size_left, int const window_size_right,
-            cutlass::FastDivmod const& qhead_per_khead_divmod) {
-
-        auto [n_block_min, n_block_max] = get_n_block_min_max(
-            seqlen_info, m_block, bidb, split_idx, num_splits,
-            window_size_left, window_size_right, qhead_per_khead_divmod);
-        int const idx_k_new_min = std::max(n_block_min * kBlockN - seqlen_info.seqlen_k_og, 0);
-        int const idx_k_new_max = std::min(n_block_max * kBlockN - seqlen_info.seqlen_k_og, seqlen_info.seqlen_k_new);
-        int const n_block_new_min = idx_k_new_min / kBlockN;
-        int const n_block_new_max = idx_k_new_max > idx_k_new_min ? cute::ceil_div(idx_k_new_max, kBlockN) : n_block_new_min;
-        // if (threadIdx.x == 128 && m_block == 0) { printf("bidb = %d, seqlen_k_new = %d, seqlen_k_og = %d, n_block_min = %d, n_block_max = %d, idx_k_new_min = %d, idx_k_new_max = %d, n_block_new_min = %d, n_block_new_max = %d\n", bidb, seqlen_k_new, seqlen_k_og, n_block_min, n_block_max, idx_k_new_min, idx_k_new_max, n_block_new_min, n_block_new_max);}
-        return {n_block_new_min, n_block_new_max};
-    }
-
-    static
-    CUTLASS_DEVICE
-    cute::tuple<int, int> get_m_block_min_max(
-            SeqlenInfo_t const& seqlen_info,
-            int const n_block, int const bidb,
-            int const window_size_left, int const window_size_right, int const sink_token_length) {
-
-        int const seqlen_q = seqlen_info.seqlen_q;
-        int const seqlen_k = seqlen_info.seqlen_k;
-        int m_block_max = cute::ceil_div(seqlen_q, kBlockM);
-        if constexpr (Is_local) {
-            if (n_block >= cute::ceil_div(sink_token_length, kBlockN)) {
-                m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + seqlen_q - seqlen_k + window_size_left, kBlockM));
-            }
-        }
-        int m_block_min = 0;
-        if constexpr (Is_causal || Is_local) {
-            m_block_min = std::max(m_block_min, (n_block * kBlockN + seqlen_q - seqlen_k - window_size_right) / kBlockM);
-        }
-        return {m_block_min, m_block_max};
-    }
-
-};
-
-} // namespace flash
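Note on the `Split` branch in the deleted code above: it reads `split_idx` as two packed 16-bit fields, with the low half carrying the split index and the high half optionally carrying a dynamically chosen split count. A minimal host-side sketch of that encoding (the `pack_split_idx` helper is illustrative, not part of this repo):

    #include <cassert>
    #include <cstdint>
    #include <cstdio>

    // Pack a dynamic split count (high 16 bits) and a split index (low 16 bits)
    // into one int, matching the decode in BlockMN::get_n_block_min_max.
    static int pack_split_idx(int num_splits_dynamic, int split_idx) {
        assert(num_splits_dynamic < (1 << 16) && split_idx < (1 << 16));
        return static_cast<int>((static_cast<uint32_t>(num_splits_dynamic) << 16)
                                | static_cast<uint32_t>(split_idx));
    }

    int main() {
        int packed = pack_split_idx(/*num_splits_dynamic=*/3, /*split_idx=*/2);
        // Decode exactly as the kernel does.
        uint32_t num_splits_dynamic = static_cast<uint32_t>(packed) >> 16;
        int split_idx_actual = packed & 0x0000FFFF;
        printf("num_splits_dynamic=%u split_idx=%d\n", num_splits_dynamic, split_idx_actual);
        return 0;
    }

When the high half is zero, the kernel falls back to the static `num_splits` argument instead.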
        flash-attn/copy_sm90_bulk_reduce.hpp
    DELETED
    
@@ -1,49 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
- ******************************************************************************/
-
-#pragma once
-
-#include <cute/arch/copy_sm90_tma.hpp>
-
-namespace cute
-{
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-struct SM90_BULK_REDUCE_ADD
-{
-  CUTE_HOST_DEVICE static void
-  copy(float const* smem_ptr,
-       float      * gmem_ptr, int32_t store_bytes)
-  {
-#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
-    uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
-    asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [%0], [%1], %2;\n"
-                     :
-                     : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(store_bytes)
-                     : "memory");
-#else
-    CUTE_INVALID_CONTROL_PATH("Trying to use BULK_REDUCE_ADD without CUTE_ARCH_TMA_SM90_ENABLED.");
-#endif
-  }
-
-  CUTE_HOST_DEVICE static void
-  copy(float const* smem_ptr,
-       float      * gmem_ptr, int32_t store_bytes, uint64_t cache_hint)
-  {
-#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
-    uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
-    asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.L2::cache_hint.add.f32 [%0], [%1], %2, %3;\n"
-                     :
-                     : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(store_bytes), "l"(cache_hint)
-                     : "memory");
-#else
-    CUTE_INVALID_CONTROL_PATH("Trying to use BULK_REDUCE_ADD without CUTE_ARCH_TMA_SM90_ENABLED.");
-#endif
-  }
-};
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-} // end namespace cute
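For reference, the `cp.reduce.async.bulk ... add.f32` PTX used above accumulates a shared-memory buffer into global memory rather than overwriting it. A CPU-side sketch of the assumed elementwise semantics (ignoring asynchrony and bulk-group completion, which the real instruction requires):

    #include <cstdint>
    #include <cstdio>

    // Model of the reduce-add: each f32 from the "shared" buffer is added into
    // the corresponding "global" element; nothing is overwritten.
    void bulk_reduce_add_model(const float* smem, float* gmem, int32_t store_bytes) {
        int n = store_bytes / static_cast<int32_t>(sizeof(float));
        for (int i = 0; i < n; ++i) gmem[i] += smem[i];
    }

    int main() {
        float smem[4] = {1, 2, 3, 4};
        float gmem[4] = {10, 10, 10, 10};
        bulk_reduce_add_model(smem, gmem, sizeof(smem));
        for (float v : gmem) printf("%g ", v);  // prints: 11 12 13 14
        printf("\n");
        return 0;
    }

The second overload differs only in passing an L2 cache hint; the arithmetic is the same.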
        flash-attn/cuda_check.h
    DELETED
    
@@ -1,19 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2024, Tri Dao.
- ******************************************************************************/
-
-#pragma once
-
-#include <assert.h>
-#include <stdlib.h>
-
-#define CHECK_CUDA(call)                                                                                  \
-    do {                                                                                                  \
-        cudaError_t status_ = call;                                                                       \
-        if (status_ != cudaSuccess) {                                                                     \
-            fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
-            exit(1);                                                                                      \
-        }                                                                                                 \
-    } while(0)
-
-#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError())
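Typical host-side usage of these macros, as a sketch (assumes a CUDA toolchain; the `scale` kernel is illustrative, not from this repo):

    #include <cuda_runtime.h>
    #include <cstdio>
    #include "cuda_check.h"

    __global__ void scale(float* x, float a, int n) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) x[i] *= a;
    }

    int main() {
        const int n = 1024;
        float* d_x = nullptr;
        CHECK_CUDA(cudaMalloc(&d_x, n * sizeof(float)));  // exits with file:line on failure
        scale<<<(n + 255) / 256, 256>>>(d_x, 2.0f, n);
        CHECK_CUDA_KERNEL_LAUNCH();                       // catches invalid launch configs
        CHECK_CUDA(cudaDeviceSynchronize());
        CHECK_CUDA(cudaFree(d_x));
        return 0;
    }

Note that the header relies on the including translation unit for `fprintf` and the CUDA runtime declarations, which is why the sketch includes `<cstdio>` and `<cuda_runtime.h>` itself.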
        flash-attn/epilogue_bwd.hpp
    DELETED
    
    | @@ -1,523 +0,0 @@ | |
| 1 | 
            -
            /******************************************************************************
         | 
| 2 | 
            -
             * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
         | 
| 3 | 
            -
             ******************************************************************************/
         | 
| 4 | 
            -
             | 
| 5 | 
            -
            #pragma once
         | 
| 6 | 
            -
             | 
| 7 | 
            -
            #include "cutlass/cutlass.h"
         | 
| 8 | 
            -
            #include "cutlass/barrier.h"
         | 
| 9 | 
            -
            #include "cute/tensor.hpp"
         | 
| 10 | 
            -
```cpp
#include "cutlass/gemm/collective/builders/sm90_common.inl"

#include "seqlen.h"
#include "named_barrier.hpp"
#include "utils.h"

namespace flash {

using namespace cute;

template <class TileShape_MNK_, class Element_, class ArchTag_,
          int NumEpilogueThreads_, bool Varlen_, bool dKV_swapAB_, int AtomLayoutKdKV=1>
struct CollectiveEpilogueBwd {

    using TileShape_MNK = TileShape_MNK_;
    using Element = Element_;
    using ArchTag = ArchTag_;
    static constexpr int NumEpilogueThreads = NumEpilogueThreads_;
    static constexpr bool Varlen = Varlen_;
    static constexpr bool dKV_swapAB = dKV_swapAB_;
    static constexpr bool Use_TMA = !Varlen && ArchTag::kMinComputeCapability >= 90;

    static_assert(ArchTag::kMinComputeCapability >= 80);

    using GmemTiledCopydKVTMA = cute::SM90_TMA_STORE;

    // These are for storing the output tensor without TMA (e.g., for setting output to zero)
    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(get<2>(TileShape_MNK{}) % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
    static constexpr int kHeadDim = get<2>(TileShape_MNK{});
    static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerLoad, NumEpilogueThreads);
    static_assert(NumEpilogueThreads % kGmemThreadsPerRow == 0, "NumEpilogueThreads must be a multiple of kGmemThreadsPerRow");
    using GmemLayoutAtom = Layout<Shape <Int<NumEpilogueThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                  Stride<Int<kGmemThreadsPerRow>, _1>>;
    using GmemTiledCopydKV = decltype(
        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{}));  // Val layout, 8 or 16 vals per store
```
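The vector width and thread split above are pure arithmetic on the element size: one 16-byte `uint128_t` store carries 8 bf16/fp16 values, and the threads that share a row must divide both the row's vector count and the thread count. A standalone sketch of that computation, assuming a hypothetical head dim of 128 and 256 epilogue threads:

```cpp
#include <numeric>

// Model of kGmemElemsPerLoad / kGmemThreadsPerRow above, for 2-byte elements,
// head dim 128 and 256 epilogue threads (these sizes are hypothetical).
constexpr int kElemBytes         = 2;                // bf16 / fp16
constexpr int kGmemElemsPerLoad  = 16 / kElemBytes;  // 8 values per 128-bit store
constexpr int kHeadDim           = 128;
constexpr int kNumThreads        = 256;
constexpr int kGmemThreadsPerRow = std::gcd(kHeadDim / kGmemElemsPerLoad, kNumThreads);  // gcd(16, 256) = 16
constexpr int kRowsPerCopyStep   = kNumThreads / kGmemThreadsPerRow;                     // 16 rows per copy step

static_assert(kGmemThreadsPerRow == 16 && kRowsPerCopyStep == 16);
```

For this configuration each copy step covers 16 rows of 16 threads with 8 values each, i.e. a full 16 x 128 sub-tile per instruction. The listing continues: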
```cpp
    using SmemLayoutAtomdKVTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
                                          // TODO: do we have to change this if dKV_swapAB is true?
                                          decltype(cute::get<1>(TileShape_MNK{})), Int<CUTE_STATIC_V(cute::get<2>(TileShape_MNK{})) / AtomLayoutKdKV>>());
    using SmemLayoutdKVTMA = decltype(tile_to_shape(SmemLayoutAtomdKVTMA{}, select<1, 2>(TileShape_MNK{})));
    using SmemLayoutdKVtTMA =
        decltype(cute::composition(SmemLayoutdKVTMA{},
                                   make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
                                               make_stride(decltype(get<1>(TileShape_MNK{})){}, _1{}))));

    // If we don't use TMA
    static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : (kHeadDim % 32 == 0 ? 32 : 16);
    static constexpr int kSwizzle = kBlockKSmem == 64 ? 3 : (kBlockKSmem == 32 ? 2 : 1);
    using SmemLayoutAtomdKVSTG =
        decltype(composition(Swizzle<kSwizzle, 3, 3>{},
                             Layout<Shape<Int<8>, Int<kBlockKSmem>>,
                                    Stride<Int<kBlockKSmem>, _1>>{}));

    using SmemLayoutAtomdKV = std::conditional_t<Use_TMA, SmemLayoutAtomdKVTMA, SmemLayoutAtomdKVSTG>;
    using SmemLayoutdKV = decltype(tile_to_shape(SmemLayoutAtomdKV{}, select<1, 2>(TileShape_MNK{})));
    using SmemLayoutdKVt =
        decltype(cute::composition(SmemLayoutdKV{},
                                   make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
                                               make_stride(decltype(get<1>(TileShape_MNK{})){}, _1{}))));

    using SmemCopyAtomdKV = Copy_Atom<
        std::conditional_t<
            ArchTag::kMinComputeCapability >= 90,
            std::conditional_t<!dKV_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
            AutoVectorizingCopyWithAssumedAlignment<128>
        >,
        Element>;
```
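For the non-TMA store path, the swizzle parameters track the shared-memory row width chosen just above; wider rows get more swizzle bits, presumably to spread 128-byte rows across all 32 shared-memory banks. A sketch that mirrors the two ternaries and checks the head dims this kernel family is instantiated for:

```cpp
// Mirror of the kBlockKSmem / kSwizzle selection above (same ternaries, plain ints).
constexpr int block_k_smem(int hdim) { return hdim % 64 == 0 ? 64 : (hdim % 32 == 0 ? 32 : 16); }
constexpr int swizzle_bits(int bk)   { return bk == 64 ? 3 : (bk == 32 ? 2 : 1); }

static_assert(block_k_smem(64)  == 64 && swizzle_bits(block_k_smem(64))  == 3);  // Swizzle<3,3,3>
static_assert(block_k_smem(96)  == 32 && swizzle_bits(block_k_smem(96))  == 2);  // Swizzle<2,3,3>
static_assert(block_k_smem(128) == 64 && swizzle_bits(block_k_smem(128)) == 3);  // Swizzle<3,3,3>
static_assert(block_k_smem(192) == 64 && swizzle_bits(block_k_smem(192)) == 3);  // Swizzle<3,3,3>
```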
```cpp
    static constexpr size_t SmemAlignmentdKV = ArchTag::kMinComputeCapability >= 90 ? cutlass::detail::alignment_for_swizzle(SmemLayoutdKV{}) : 128;
    static_assert(SmemAlignmentdKV >= 128, "Require at least 128B alignment");

    struct TensorStorage : cute::aligned_struct<SmemAlignmentdKV> {
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdKV>, SmemAlignmentdKV> smem_dk;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdKV>, SmemAlignmentdKV> smem_dv;
    };

    using ShapedKV = cute::Shape<int32_t, int32_t, int32_t, int32_t>;  // (seqlen_k, d, head, batch)
    using StridedKV = cute::Stride<int64_t, _1, int64_t, int64_t>;

    using TMA_dKV = std::conditional_t<
        Use_TMA,
        decltype(make_tma_copy(
            GmemTiledCopydKVTMA{},
            make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)), ShapedKV{}, StridedKV{}),
            SmemLayoutdKVTMA{},
            select<1, 2>(TileShape_MNK{}),
            _1{})),  // no mcast for dKV
        std::nullptr_t
        >;
```
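`ShapedKV` orders the modes as (seqlen_k, d, head, batch) with a compile-time unit stride on `d`; for a conventional contiguous (batch, seqlen_k, head, d) gradient buffer the remaining strides follow directly. A sketch with hypothetical sizes (the contiguous layout is an assumption; only the mode order and the `_1` come from the code above):

```cpp
#include <cstdint>

// Strides matching StridedKV = (int64_t, _1, int64_t, int64_t) for a contiguous
// [batch][seqlen_k][num_heads][d] buffer; sizes are hypothetical.
constexpr int64_t batch = 4, seqlen_k = 1024, num_heads = 8, d = 128;
constexpr int64_t stride_seq   = num_heads * d;             // 1024: one step along seqlen_k
constexpr int64_t stride_d     = 1;                         // innermost, the _1 above
constexpr int64_t stride_head  = d;                         // 128: one head over
constexpr int64_t stride_batch = seqlen_k * num_heads * d;  // 1048576: one batch entry

static_assert(stride_seq == 1024 && stride_batch == 1048576);
```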
```cpp
    // Host side kernel arguments
    struct Arguments {
        Element* ptr_dK;
        ShapedKV const shape_dK;
        StridedKV const stride_dK;
        Element* ptr_dV;
        StridedKV const stride_dV;
        int const num_heads_q;
        int* dk_semaphore;
        int* dv_semaphore;
        int const* cu_seqlens;
        int const* seqused;
    };

    // Device side kernel params
    struct Params {
        Element* ptr_dK;
        ShapedKV const shape_dK;
        StridedKV const stride_dK;
        Element* ptr_dV;
        StridedKV const stride_dV;
        TMA_dKV tma_store_dK, tma_store_dV;
        int const* cu_seqlens = nullptr;
        int const* seqused = nullptr;
    };
```
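On the host, `Arguments` carries the raw dK/dV pointers plus geometry, and `to_underlying_arguments` (below) turns them into `Params`, building the TMA store descriptors when `Use_TMA` holds. A hedged sketch of what a caller might pass for the non-varlen, non-deterministic case; the instantiation and every size here are hypothetical, not taken from the actual launch code:

```cpp
#include <cstdint>

// Hypothetical instantiation: 128x128x128 tile, bf16, SM90, 256 epilogue threads.
using Epilogue = flash::CollectiveEpilogueBwd<
    cute::Shape<cute::_128, cute::_128, cute::_128>,
    cutlass::bfloat16_t, cutlass::arch::Sm90,
    /*NumEpilogueThreads=*/256, /*Varlen=*/false, /*dKV_swapAB=*/false>;

typename Epilogue::Params make_params(cutlass::bfloat16_t* dk, cutlass::bfloat16_t* dv,
                                      int32_t seqlen_k, int32_t num_heads_kv, int32_t batch) {
    int64_t const d = 128;
    int64_t const stride_seq = num_heads_kv * d, stride_head = d,
                  stride_batch = seqlen_k * stride_seq;
    typename Epilogue::Arguments args{
        dk, {seqlen_k, (int32_t)d, num_heads_kv, batch},          // shape_dK
        {stride_seq, cute::_1{}, stride_head, stride_batch},      // stride_dK
        dv, {stride_seq, cute::_1{}, stride_head, stride_batch},  // stride_dV
        /*num_heads_q=*/num_heads_kv,                             // MHA: one Q head per KV head
        /*dk_semaphore=*/nullptr, /*dv_semaphore=*/nullptr,       // non-deterministic path
        /*cu_seqlens=*/nullptr, /*seqused=*/nullptr};             // non-varlen path
    return Epilogue::to_underlying_arguments(args);
}
```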
```cpp
    static Params
    to_underlying_arguments(Arguments const& args) {
        Tensor mdK = make_tensor(make_gmem_ptr(args.ptr_dK), args.shape_dK, args.stride_dK);
        Tensor mdV = make_tensor(make_gmem_ptr(args.ptr_dV), args.shape_dK, args.stride_dV);
        TMA_dKV tma_store_dK = [&] {
            if constexpr (Use_TMA) {
                return make_tma_copy(GmemTiledCopydKVTMA{}, mdK, SmemLayoutdKVTMA{}, select<1, 2>(TileShape_MNK{}), _1{}); // no mcast for dKV
            } else {
                return nullptr;
            }
        }();
        TMA_dKV tma_store_dV = [&] {
            if constexpr (Use_TMA) {
                return make_tma_copy(GmemTiledCopydKVTMA{}, mdV, SmemLayoutdKVTMA{}, select<1, 2>(TileShape_MNK{}), _1{}); // no mcast for dKV
            } else {
                return nullptr;
            }
        }();
        return {args.ptr_dK, args.shape_dK, args.stride_dK, args.ptr_dV, args.stride_dV,
                tma_store_dK, tma_store_dV, args.cu_seqlens, args.seqused};
    }

    /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
    CUTLASS_DEVICE
    static void prefetch_tma_descriptors(Params const& params) {
        if constexpr (Use_TMA) {
            cute::prefetch_tma_descriptor(params.tma_store_dK.get_tma_descriptor());
            cute::prefetch_tma_descriptor(params.tma_store_dV.get_tma_descriptor());
        }
    }

    template <typename SharedStorage, typename FrgTensorO, typename TiledMma>
    CUTLASS_DEVICE void
    store(Params const& params,
          FrgTensorO const& tdKrdK,
          FrgTensorO const& tdVrdV,
          SharedStorage& shared_storage,
          TiledMma tiled_mma,
          int thread_idx,
          cute::tuple<int32_t, int32_t, int32_t> const& block_coord
          ) {

        auto [n_block, bidh, bidb] = block_coord;
        Tensor sdK = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dk.data()), SmemLayoutdKV{}));
        Tensor sdV = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dv.data()), SmemLayoutdKV{}));
        Tensor sdKt = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dk.data()), SmemLayoutdKVt{}));
        Tensor sdVt = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dv.data()), SmemLayoutdKVt{}));
        auto smem_tiled_copy_dKV = make_tiled_copy_C(SmemCopyAtomdKV{}, tiled_mma);
        auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(thread_idx);

        Tensor tdVrdV_out = make_tensor_like<Element>(tdVrdV);
        flash::convert_type_out(tdVrdV, tdVrdV_out);
        Tensor tdKrdK_out = make_tensor_like<Element>(tdKrdK);
        flash::convert_type_out(tdKrdK, tdKrdK_out);
        Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(tdKrdK_out);        // ((Atom,AtomNum), MMA_M, MMA_N)
        Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(tdVrdV_out);        // ((Atom,AtomNum), MMA_M, MMA_N)
        Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(cute::conditional_return<!dKV_swapAB>(sdK, sdKt));     // ((Atom,AtomNum),PIPE_M,PIPE_N)
        Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(cute::conditional_return<!dKV_swapAB>(sdV, sdVt));     // ((Atom,AtomNum),PIPE_M,PIPE_N)

        // Make sure all WGs have finished reading K and V
        flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
        cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
        cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
        if constexpr (Use_TMA) {
            cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
            cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
                                                cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);

            Tensor mdK = params.tma_store_dK.get_tma_tensor(params.shape_dK);
            Tensor mdV = params.tma_store_dV.get_tma_tensor(params.shape_dK);
            Tensor gdK = local_tile(mdK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{}));  // (M, K)
            Tensor gdV = local_tile(mdV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{}));  // (M, K)
            auto block_tma_dK = params.tma_store_dK.get_slice(_0{});
            auto block_tma_dV = params.tma_store_dV.get_slice(_0{});
            Tensor tdKgdK = block_tma_dK.partition_D(gdK);  // (TMA, TMA_M, TMA_K)
            Tensor tdKsdK = block_tma_dK.partition_S(sdK);  // (TMA, TMA_M, TMA_K)
            Tensor tdVgdV = block_tma_dV.partition_D(gdV);  // (TMA, TMA_M, TMA_K)
            Tensor tdVsdV = block_tma_dV.partition_S(sdV);  // (TMA, TMA_M, TMA_K)
            int warp_idx_sync = __shfl_sync(0xffffffff, thread_idx / cutlass::NumThreadsPerWarp, 0);
            if (warp_idx_sync == NumEpilogueThreads / cutlass::NumThreadsPerWarp - 1) {
                cutlass::arch::NamedBarrier::sync(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
                                                  cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
                if (cute::elect_one_sync()) {
                    cute::copy(params.tma_store_dV, tdVsdV, tdVgdV);
                    cute::copy(params.tma_store_dK, tdKsdK, tdKgdK);
                    tma_store_arrive();
                }
            }
            tma_store_wait<0>();
            // // Tell warp 0 that smem_k and smem_v are ready
            // cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast<uint32_t>(BwdNamedBarriers::KVEmpty) /*id*/);

        } else {
            flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
            static constexpr int kBlockN = get<1>(TileShape_MNK{});
            flash::SeqlenInfo<Varlen, kBlockN> seqlen_info{bidb, size<0>(params.shape_dK), params.cu_seqlens, params.seqused};
            bool const is_varlen = Varlen && params.cu_seqlens;
            Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0);
            Tensor gdK = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{}));  // (M, K)
            Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dK, params.stride_dV)(_, _, bidh, !is_varlen ? bidb : 0);
            Tensor gdV = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{}));  // (M, K)

            GmemTiledCopydKV gmem_tiled_copy_dKV;
            auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(thread_idx);
            Tensor tdKVgdV = gmem_thr_copy_dKV.partition_D(gdV);
            Tensor tdKVsdV = gmem_thr_copy_dKV.partition_S(sdV);
            Tensor tdKVgdK = gmem_thr_copy_dKV.partition_D(gdK);
            Tensor tdKVsdK = gmem_thr_copy_dKV.partition_S(sdK);
            Tensor tdKVrdV = make_fragment_like(tdKVgdV);
            Tensor tdKVrdK = make_fragment_like(tdKVgdK);
            Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{}));  // (BLK_N,BLK_K) -> (blk_n,blk_k)
            // Repeat the partitioning with identity layouts
            Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
            Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKVgdV)));
            #pragma unroll
            for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); }
            // Need to check OOB when reading from smem if kBlockN isn't evenly tiled
            static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(size<0>(GmemLayoutAtom{})) == 0;
            flash::copy</*Is_even_MN=*/EvenN, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false>(
                gmem_tiled_copy_dKV, tdKVsdV, tdKVrdV, tdKVcdKV, tdKVpdKV, kBlockN);
            flash::copy</*Is_even_MN=*/EvenN, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false>(
                gmem_tiled_copy_dKV, tdKVsdK, tdKVrdK, tdKVcdKV, tdKVpdKV, kBlockN);
            // // Tell warp 0 that smem_k and smem_v are ready
            // cutlass::arch::fence_view_async_shared(); // ensure smem reads are done before next TMA to smem_k/v
            // flash::named_barrier_arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast<uint32_t>(BwdNamedBarriers::KVEmpty) /*id*/);
            // Clear_OOB_K must be false since we don't want to write zeros to gmem
            flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
                gmem_tiled_copy_dKV, tdKVrdV, tdKVgdV, tdKVcdKV, tdKVpdKV, std::min(seqlen_info.seqlen - n_block * kBlockN, kBlockN)
            );
            flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
                gmem_tiled_copy_dKV, tdKVrdK, tdKVgdK, tdKVcdKV, tdKVpdKV, std::min(seqlen_info.seqlen - n_block * kBlockN, kBlockN)
            );
        }
    }
```
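The TMA branch of `store()` above uses an arrive/sync split so only the last warp waits for everyone's smem writes, and within that warp a single elected lane issues the bulk store. A stripped-down CUDA model of that handshake, with `__syncthreads` standing in for the named barrier and a plain loop standing in for the TMA copy (everything here is a simplification, not the kernel's actual code):

```cuda
// Simplified model of the TMA-branch handshake: every thread stages its
// fragment to smem, then the last warp's elected lane issues the store.
__global__ void epilogue_handshake_sketch(float* out, const float* frag) {
    __shared__ float smem[256];                        // assumes blockDim.x <= 256
    int tid = threadIdx.x;
    smem[tid] = frag[blockIdx.x * blockDim.x + tid];   // stage result to smem
    __syncthreads();                                   // models NamedBarrier arrive + sync
    bool last_warp = (tid / 32) == (blockDim.x / 32) - 1;
    if (last_warp && (tid % 32) == 0) {                // models cute::elect_one_sync()
        for (int i = 0; i < blockDim.x; ++i)           // models the single TMA bulk store
            out[blockIdx.x * blockDim.x + i] = smem[i];
    }
}
```

The struct's remaining members follow: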
```cpp
    CUTLASS_DEVICE void
    store_tail() {
        // if constexpr (Use_TMA) { tma_store_wait<0>(); }
    }

    // Write 0 to dK and dV
    CUTLASS_DEVICE void
    store_zero(
         Params const& params,
         int thread_idx,
         cute::tuple<int32_t, int32_t, int32_t> const& block_coord
         ) {
        static constexpr int kBlockN = get<1>(TileShape_MNK{});
        auto [n_block, bidh, bidb] = block_coord;
        flash::SeqlenInfo<Varlen, kBlockN> seqlen_info{bidb, size<0>(params.shape_dK), params.cu_seqlens, params.seqused};
        bool const is_varlen = Varlen && params.cu_seqlens;
        Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0);
        Tensor gdK = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{}));  // (M, K)
        Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dK, params.stride_dV)(_, _, bidh, !is_varlen ? bidb : 0);
        Tensor gdV = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{}));  // (M, K)

        GmemTiledCopydKV gmem_tiled_copy_dKV;
        auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(thread_idx);
        Tensor tdKVgdK = gmem_thr_copy_dKV.partition_D(gdK);
        Tensor tdKVgdV = gmem_thr_copy_dKV.partition_D(gdV);
        Tensor tdKVrdKV = make_fragment_like(tdKVgdK);
        clear(tdKVrdKV);
        // Construct identity layout for gdKV
        Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{}));  // (BLK_M,BLK_K) -> (blk_m,blk_k)
        // Repeat the partitioning with identity layouts
        Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
        Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKVgdK)));
        #pragma unroll
        for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); }
        // Clear_OOB_K must be false since we don't want to write zeros to gmem
        flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
            gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdK, tdKVcdKV, tdKVpdKV, seqlen_info.seqlen - n_block * kBlockN
        );
        flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
            gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdV, tdKVcdKV, tdKVpdKV, seqlen_info.seqlen - n_block * kBlockN
        );
    }

};
```
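Both `store()`'s non-TMA branch and `store_zero` clamp the row count of the final block so nothing is written past the (possibly variable) sequence length. The bound is worth seeing with numbers, using a hypothetical sequence length:

```cpp
#include <algorithm>

// Rows written by block n_block for kBlockN = 128 and a hypothetical seqlen of 1000.
constexpr int kBlockN = 128;
constexpr int seqlen  = 1000;
constexpr int rows_in_block(int n_block) { return std::min(seqlen - n_block * kBlockN, kBlockN); }

static_assert(rows_in_block(0) == 128);  // interior blocks are full
static_assert(rows_in_block(7) == 104);  // last block: 1000 - 7*128 = 104 rows; the rest are masked
```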
```cpp
template <class TileShape_MNK_, class ElementAccum, class ArchTag_,
          int NumEpilogueThreads_, bool Varlen_, bool Deterministic>
struct CollectiveEpilogueBwdGQA {

    using TileShape_MNK = TileShape_MNK_;
    using Element = ElementAccum;
    using ArchTag = ArchTag_;
    static constexpr int NumEpilogueThreads = NumEpilogueThreads_;
    static constexpr bool Varlen = Varlen_;
    static constexpr bool Use_TMA = ArchTag::kMinComputeCapability >= 90;

    static_assert(ArchTag::kMinComputeCapability >= 80);

    static constexpr int kBlockN = get<1>(TileShape_MNK{});
    static constexpr int kHeadDim = get<2>(TileShape_MNK{});
    static_assert(NumEpilogueThreads % cutlass::NumThreadsPerWarp == 0, "NumEpilogueThreads must be a multiple of NumThreadsPerWarp");
    static constexpr int NumWarpGroups = NumEpilogueThreads / cutlass::NumThreadsPerWarpGroup;
    // Thread layout, 256 or 384 threads per row
    // We split into NumWarpGroups so that we can use the same postprocessing kernel as dQ
    using R2SLayoutAtomdKVaccum = Layout<Shape<Int<cutlass::NumThreadsPerWarpGroup>, Int<NumWarpGroups>>>;
    using R2STiledCopydKVaccum = decltype(make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, R2SLayoutAtomdKVaccum{},
                                                          Layout<Shape<_4>>{}));  // Val layout, 4 vals per store
    // For Sm80
    using R2GLayoutAtomdKVaccum = Layout<Shape<Int<NumEpilogueThreads>>>;
    using R2GTiledCopydKVaccum = decltype(make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, R2GLayoutAtomdKVaccum{},
                                                          Layout<Shape<_1>>{}));  // Val layout, 1 val per store
```
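The GQA variant keeps dK/dV in fp32 accumulators and splits the tile per warpgroup, so the same postprocessing kernel as dQ can consume it. The layout arithmetic, evaluated for one plausible configuration (256 epilogue threads and the hdim-128 tile with kBlockN = 80 that the alignment comment further down refers to):

```cpp
// Arithmetic behind R2SLayoutAtomdKVaccum / SmemLayoutdKVaccum, assuming
// 256 epilogue threads and a hypothetical 80 x 128 (kBlockN x kHeadDim) tile.
constexpr int kThreadsPerWarpGroup = 128;
constexpr int kNumEpilogueThreads  = 256;
constexpr int kNumWarpGroups       = kNumEpilogueThreads / kThreadsPerWarpGroup;  // 2
constexpr int kBlockN = 80, kHeadDim = 128;
constexpr int kAccumPerWarpGroup   = kBlockN * kHeadDim / kNumWarpGroups;         // 5120 fp32 values

static_assert(kNumWarpGroups == 2 && kAccumPerWarpGroup == 5120);
// Each R2S step moves 128 threads x 4 values = 512 fp32 values per warpgroup.
```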
```cpp
    using SmemLayoutdKVaccum = Layout<Shape<Int<kBlockN * kHeadDim / NumWarpGroups>, Int<NumWarpGroups>>>;
    using SmemLayoutdKVaccumFlat = Layout<Shape<Int<kBlockN * kHeadDim>>>;

    // Strangely without this SmemAlignment, the total smem for hdim 128 (80 x 128) is 228KB even though we
    // only need 227KB. We use the same alignment as the non-GQA epilogue to avoid this issue.
    static constexpr int SmemAlignment = kHeadDim % 64 == 0 ? 1024 : (kHeadDim % 32 == 0 ? 512 : 256);
    struct TensorStorageTMA : cute::aligned_struct<SmemAlignment> {
        cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutdKVaccum>, SmemAlignment> smem_dkv;
    };
    struct TensorStorageSTG {
        cute::array<ElementAccum, 0> smem_dkv;
    };
    using TensorStorage = std::conditional_t<Use_TMA, TensorStorageTMA, TensorStorageSTG>;

    using ShapedKV = cute::Shape<int32_t, int32_t, int32_t>;  // (seqlen_k_rounded * d, head, batch)
    using StridedKV = cute::Stride<_1, int64_t, int64_t>;

    // Host side kernel arguments
    struct Arguments {
        ElementAccum* ptr_dKaccum;
        ShapedKV const shape_dKaccum;
        StridedKV const stride_dKaccum;
        ElementAccum* ptr_dVaccum;
        StridedKV const stride_dVaccum;
        int num_heads_q;
        int* dk_semaphore;
        int* dv_semaphore;
        int const* cu_seqlens;
        int const* seqused;
    };

    // Device side kernel params
    struct Params {
        ElementAccum* ptr_dKaccum;
        ShapedKV const shape_dKaccum;
        StridedKV const stride_dKaccum;
        ElementAccum* ptr_dVaccum;
        StridedKV const stride_dVaccum;
        cutlass::FastDivmod qhead_per_khead_divmod;
        int* dk_semaphore;
        int* dv_semaphore;
        int const* cu_seqlens = nullptr;
        int const* seqused = nullptr;
    };

    static Params
    to_underlying_arguments(Arguments const& args) {
        if constexpr (Deterministic) {
            assert(args.dk_semaphore != nullptr);
            assert(args.dv_semaphore != nullptr);
        }
        return {args.ptr_dKaccum, args.shape_dKaccum, args.stride_dKaccum, args.ptr_dVaccum, args.stride_dVaccum,
                cutlass::FastDivmod(cute::ceil_div(args.num_heads_q, get<1>(args.shape_dKaccum))),
                args.dk_semaphore, args.dv_semaphore,
                args.cu_seqlens, args.seqused};
    }
```
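`to_underlying_arguments` packs the Q-heads-per-KV-head ratio into a `FastDivmod`, which `store()` below unpacks into a KV head index and a position within the group. The same mapping with plain integer operations, for a hypothetical 32 Q heads over 8 KV heads:

```cpp
// What qhead_per_khead_divmod.divmod(bidh_idx_in_group, bidh) computes,
// written with plain ints: hypothetical 32 Q heads sharing 8 KV heads.
constexpr int num_heads_q  = 32;
constexpr int num_heads_kv = 8;
constexpr int qhead_per_khead = (num_heads_q + num_heads_kv - 1) / num_heads_kv;  // ceil_div = 4

constexpr int bidh = 13;                                   // Q head handled by this block
constexpr int bidh_kv = bidh / qhead_per_khead;            // 3: the shared KV head
constexpr int bidh_idx_in_group = bidh % qhead_per_khead;  // 1: position within that group

static_assert(bidh_kv == 3 && bidh_idx_in_group == 1);
```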
```cpp
    /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
    CUTLASS_DEVICE
    static void prefetch_tma_descriptors(Params const& params) {
    }

    template <typename SharedStorage, typename FrgTensorO, typename TiledMma>
    CUTLASS_DEVICE void
    store(Params const& params,
          FrgTensorO const& tdKrdK,
          FrgTensorO const& tdVrdV,
          SharedStorage& shared_storage,
          TiledMma tiled_mma,
          int thread_idx,
          cute::tuple<int32_t, int32_t, int32_t> const& block_coord
          ) {

        auto [n_block, bidh, bidb] = block_coord;
        int bidh_idx_in_group;
        int bidh_kv = params.qhead_per_khead_divmod.divmod(bidh_idx_in_group, bidh);
        Tensor sdKV = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dkv.data()), SmemLayoutdKVaccum{});
        Tensor sdKV_flat = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dkv.data()), SmemLayoutdKVaccumFlat{});
        static constexpr int dKV_TMA_num_bytes = CUTE_STATIC_V(size(sdKV_flat)) * sizeof(ElementAccum);

        flash::SeqlenInfo<Varlen, kBlockN> seqlen_info{bidb, size<0>(params.shape_dKaccum), params.cu_seqlens, params.seqused};
        bool const is_varlen = Varlen && params.cu_seqlens;
        Tensor mdKaccum = make_tensor(make_gmem_ptr(params.ptr_dKaccum), params.shape_dKaccum, params.stride_dKaccum)(_, bidh_kv, !is_varlen ? bidb : 0);
        Tensor mdVaccum = make_tensor(make_gmem_ptr(params.ptr_dVaccum), params.shape_dKaccum, params.stride_dVaccum)(_, bidh_kv, !is_varlen ? bidb : 0);
        Tensor gdKaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdKaccum), Shape<Int<kBlockN * kHeadDim>>{}, make_coord(n_block));  // (M * K)
        Tensor gdVaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdVaccum), Shape<Int<kBlockN * kHeadDim>>{}, make_coord(n_block));  // (M * K)

        R2STiledCopydKVaccum r2s_tiled_copy_dKVaccum;
        auto r2s_thr_copy_dKVaccum = r2s_tiled_copy_dKVaccum.get_thread_slice(thread_idx);
        Tensor tdKVsdKVaccum = r2s_thr_copy_dKVaccum.partition_D(sdKV);

        // Only used if !Use_TMA
        R2GTiledCopydKVaccum r2g_tiled_copy_dKVaccum;
        auto r2g_thr_copy_dKVaccum = r2g_tiled_copy_dKVaccum.get_thread_slice(thread_idx);

        // Make sure all WGs have finished reading K and V, otherwise we get racy dQ
        // because smem_q could be changed.
        flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
        if constexpr (Use_TMA) {
            Tensor taccdKVrdV = r2s_thr_copy_dKVaccum.retile_S(tdVrdV); // ((Atom,AtomNum), MMA_M, MMA_N)
            cute::copy(r2s_tiled_copy_dKVaccum, taccdKVrdV, tdKVsdKVaccum);
        }

        int const num_batch = get<2>(params.shape_dKaccum);
        int const num_head_kv = get<1>(params.shape_dKaccum);
        int *lock_ptr = !Deterministic ? nullptr : params.dv_semaphore + bidb * num_head_kv + bidh_kv;
        using Barrier = cutlass::GenericBarrier<cutlass::detail::SyncwarpSync>;

        if constexpr (Deterministic) {
            Barrier::wait_eq(lock_ptr, thread_idx, n_block * num_batch * num_head_kv, bidh_idx_in_group);
        }
        if constexpr (Use_TMA) {
            cutlass::arch::fence_view_async_shared();
            cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
            if (thread_idx == 0) {
                SM90_BULK_REDUCE_ADD::copy(raw_pointer_cast(sdKV_flat.data()), raw_pointer_cast(gdVaccum.data()), dKV_TMA_num_bytes, static_cast<uint64_t>(TMA::CacheHintSm90::EVICT_LAST));
                tma_store_arrive();
                tma_store_wait<0>();
            }
        } else {
            Tensor tdVrdV_atomic = r2g_thr_copy_dKVaccum.retile_S(tdVrdV);
            Tensor tdVgdV_atomic = r2g_thr_copy_dKVaccum.partition_D(gdVaccum);
            static_assert(CUTE_STATIC_V(size(tdVrdV_atomic)) == CUTE_STATIC_V(size(tdVgdV_atomic)));
            #pragma unroll
            for (int i = 0; i < size(tdVrdV_atomic); ++i) { atomicAdd(&tdVgdV_atomic(i), tdVrdV_atomic(i)); }
        }
        if constexpr (Deterministic) {
            Barrier::arrive_inc(lock_ptr, thread_idx, n_block * num_batch * num_head_kv);
        }

        if constexpr (Use_TMA) {
            cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
            Tensor taccdKVrdK = r2s_thr_copy_dKVaccum.retile_S(tdKrdK); // ((Atom,AtomNum), MMA_M, MMA_N)
            cute::copy(r2s_tiled_copy_dKVaccum, taccdKVrdK, tdKVsdKVaccum);
        }
        lock_ptr = !Deterministic ? nullptr : params.dk_semaphore + bidb * num_head_kv + bidh_kv;

        if constexpr (Deterministic) {
            Barrier::wait_eq(lock_ptr, thread_idx, n_block * num_batch * num_head_kv, bidh_idx_in_group);
        }
        if constexpr (Use_TMA) {
            cutlass::arch::fence_view_async_shared();
            cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
            if (thread_idx == 0) {
                SM90_BULK_REDUCE_ADD::copy(raw_pointer_cast(sdKV_flat.data()), raw_pointer_cast(gdKaccum.data()), dKV_TMA_num_bytes, static_cast<uint64_t>(TMA::CacheHintSm90::EVICT_LAST));
                tma_store_arrive();
                tma_store_wait<0>();
            }
        } else {
            Tensor tdKrdK_atomic = r2g_thr_copy_dKVaccum.retile_S(tdKrdK);
            Tensor tdKgdK_atomic = r2g_thr_copy_dKVaccum.partition_D(gdKaccum);
            static_assert(CUTE_STATIC_V(size(tdKrdK_atomic)) == CUTE_STATIC_V(size(tdKgdK_atomic)));
            #pragma unroll
            for (int i = 0; i < size(tdKrdK_atomic); ++i) { atomicAdd(&tdKgdK_atomic(i), tdKrdK_atomic(i)); }
        }
        if constexpr (Deterministic) {
            Barrier::arrive_inc(lock_ptr, thread_idx, n_block * num_batch * num_head_kv);
        }
        // // Tell warp 0 that smem_k and smem_v are ready
        // flash::named_barrier_arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast<uint32_t>(BwdNamedBarriers::KVEmpty) /*id*/);
    }
```
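Determinism in the code above means serializing accumulation across the Q heads that share one KV head: for each (n_block, batch, kv-head) semaphore cell, a block spins until the counter equals its own `bidh_idx_in_group`, adds its contribution, then increments the counter so the next head in the group proceeds, fixing the addition order. A CPU model of that protocol, using only the `wait_eq`/`arrive_inc` semantics visible above and modeling a single semaphore cell:

```cpp
#include <atomic>

// CPU model of Barrier::wait_eq / arrive_inc for one (n_block, batch, kv-head)
// cell: heads of a group accumulate into the same dK/dV in a fixed order,
// which is what makes the result bit-for-bit deterministic.
std::atomic<int> semaphore{0};

void accumulate_in_order(int bidh_idx_in_group, float* dkv, const float* partial, int n) {
    while (semaphore.load(std::memory_order_acquire) != bidh_idx_in_group) { /* spin */ }
    for (int i = 0; i < n; ++i) dkv[i] += partial[i];   // atomicAdd on the SM80 path
    semaphore.fetch_add(1, std::memory_order_release);  // release the next head in the group
}
```

The remaining members of the GQA epilogue are trivial: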
```cpp
    CUTLASS_DEVICE void
    store_tail() {
    }

    // Write 0 to dK and dV
    CUTLASS_DEVICE void
    store_zero(
         Params const& params,
         int thread_idx,
         cute::tuple<int32_t, int32_t, int32_t> const& block_coord
         ) {
        // Don't need to do anything since dKaccum and dVaccum are already zero-initialized
    }

};

} // namespace flash
```
        flash-attn/epilogue_fwd.hpp
    DELETED
    
    | @@ -1,484 +0,0 @@ | |
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cutlass/cutlass.h>
#include <cutlass/fast_math.h>  // For FastDivMod
#include "cute/tensor.hpp"

#include "cutlass/gemm/collective/builders/sm90_common.inl"
#include "cutlass/epilogue/collective/builders/sm90_common.inl"

#include "seqlen.h"
#include "named_barrier.hpp"
#include "pack_gqa.h"
#include "utils.h"

namespace flash {

using namespace cute;

template <class TileShape_MNK_PV_, class ClusterShape_, class Element_, class ArchTag_,
          int NumEpilogueThreads_, bool Varlen_, bool PackGQA_, bool Split_, bool FP8PermuteCol=false>
struct CollectiveEpilogueFwd {

    using TileShape_MNK_PV = TileShape_MNK_PV_;
    using ClusterShape = ClusterShape_;
    using Element = Element_;
    using ElementPartial = float;
    using ArchTag = ArchTag_;
    static constexpr int NumEpilogueThreads = NumEpilogueThreads_;
    static constexpr bool Varlen = Varlen_;
    static constexpr bool PackGQA = PackGQA_;
    static constexpr bool Split = Split_;
    static constexpr bool Use_smem = !(Split && !Varlen);
    static constexpr bool Use_TMA_O = ArchTag::kMinComputeCapability >= 90 && !Varlen && !Split && !PackGQA;

    static_assert(ArchTag::kMinComputeCapability >= 80);
    static_assert(ArchTag::kMinComputeCapability >= 90 || CUTE_STATIC_V(size(ClusterShape{})) == 1);
    static_assert(sizeof(Element) <= 2);

    static constexpr int kBlockM = get<0>(TileShape_MNK_PV{});
    static constexpr int kHeadDimV = get<1>(TileShape_MNK_PV{});

    static constexpr bool LargeHeadDimV = kHeadDimV > 256;

    using GmemTiledCopyOTMA = cute::SM90_TMA_STORE;

    // These are for storing the output tensor without TMA (e.g., for setting output to zero)
    static constexpr int kGmemElemsPerStore = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDimV % kGmemElemsPerStore == 0, "Headdim must be a multiple of kGmemElemsPerStore");
    // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). We want each thread to have 4 elements
    // in the M direction and 2 elements in the K direction. In the case of PackGQA, this reduces the number of times
    // we need to call divmod.
    static constexpr int kBytePerRow = kHeadDimV * sizeof(Element);
    static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element);
    static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerStore;
    // If PackGQA, we split the work of computing O_ptr among threads in the same row, so we need this to be within a warp
    static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0);
    static_assert(NumEpilogueThreads % kGmemThreadsPerRow == 0, "NumEpilogueThreads must be a multiple of kGmemThreadsPerRow");
    using GmemLayoutAtom = Layout<Shape <Int<NumEpilogueThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                  Stride<Int<kGmemThreadsPerRow>, _1>>;
    static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0, "kBlockM must be a multiple of NumEpilogueThreads / kGmemThreadsPerRow");
    using GmemTiledCopyO = decltype(
        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, Int<kGmemElemsPerStore>>>{}));  // Val layout, 8 or 16 vals per store

    using SmemLayoutAtomOTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK_PV{})), decltype(cute::get<1>(TileShape_MNK_PV{}))>());
    using SmemLayoutOTMA = decltype(tile_to_shape(SmemLayoutAtomOTMA{}, select<0, 1>(TileShape_MNK_PV{})));
    static constexpr int kSwizzle = kBlockKGmem == 128 ? 4 : (kBlockKGmem == 64 ? 3 : (kBlockKGmem == 32 ? 2 : 1));
    static constexpr int kSwizzleBase = sizeof(Element) == 4 ? 2 : (sizeof(Element) == 2 ? 3 : 4);
    using SmemLayoutAtomO = decltype(
        composition(Swizzle<kSwizzle, kSwizzleBase, kSwizzleBase>{},
                    Layout<Shape<_8, Int<kBlockKGmem>>,
                           Stride<Int<kBlockKGmem>, _1>>{}));
    using SmemLayoutOSTS = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 1>(TileShape_MNK_PV{})));
    using SmemLayoutO = std::conditional_t<ArchTag::kMinComputeCapability >= 90, SmemLayoutOTMA, SmemLayoutOSTS>;

    using ShapeO = cute::Shape<int32_t, int32_t, int32_t, int32_t, int32_t>;  // (seqlen_q, d, head, batch, num_splits)
    using StrideO = cute::Stride<int64_t, _1, int64_t, int64_t, int64_t>;
    using StrideLSE = cute::Stride<_1, int64_t, int64_t, int64_t>;            // (seqlen_q, head, batch, num_splits)
    // ((qhead_per_khead, seqlen_q), d, nheads_kv, batch, num_splits)
    using ShapeOPacked = std::conditional_t<!PackGQA, ShapeO, cute::Shape<cute::Shape<int32_t, int32_t>, int32_t, int32_t, int32_t, int32_t>>;
    using StrideOPacked = std::conditional_t<!PackGQA, StrideO, cute::Stride<cute::Stride<int64_t, int64_t>, _1, int64_t, int64_t, int64_t>>;
    // ((qhead_per_khead, seqlen_q), nheads_kv, batch, num_splits)
    using ShapeLSEPacked = std::conditional_t<!PackGQA, cute::Shape<int32_t, int32_t, int32_t, int32_t>, cute::Shape<cute::Shape<int32_t, int32_t>, int32_t, int32_t, int32_t>>;
    using StrideLSEPacked = std::conditional_t<!PackGQA, StrideLSE, cute::Stride<cute::Stride<int64_t, _1>, int64_t, int64_t, int64_t>>;

    using CopyOpR2S = std::conditional_t<
        ArchTag::kMinComputeCapability >= 90,
        // cute::SM90_U32x4_STSM_N if Element size is 2 bytes (fp16, bf16)
        decltype(cutlass::epilogue::collective::detail::sm90_get_smem_store_op_for_accumulator<StrideO, Element>()),
        AutoVectorizingCopyWithAssumedAlignment<128>
    >;
    using SmemCopyAtomO = Copy_Atom<CopyOpR2S, Element>;

    // static constexpr size_t SmemAlignmentO = cutlass::detail::alignment_for_swizzle(SmemLayoutO{});
    // static_assert(SmemAlignmentO >= 128, "Require at least 128B alignment");
    // struct TensorStorage : cute::aligned_struct<SmemAlignmentO> {
    //     cute::array_aligned<Element, Use_smem ? cute::cosize_v<SmemLayoutO> : 0, SmemAlignmentO> smem_o;
    // };
    struct TensorStorage : cute::aligned_struct<128> {
        cute::array_aligned<Element, Use_smem ? cute::cosize_v<SmemLayoutO> : 0> smem_o;
    };

    using TMA_O = std::conditional_t<
        Use_TMA_O,
        decltype(make_tma_copy(
            GmemTiledCopyOTMA{},
            make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)), ShapeO{}, StrideO{}),
            SmemLayoutOTMA{},
            select<0, 1>(TileShape_MNK_PV{}),
            _1{})),  // no mcast for O
        std::nullptr_t
    >;

    // Host side kernel arguments
    struct Arguments {
        Element* ptr_O;
        ShapeO const shape_O;
        StrideO const stride_O;
        ElementPartial* ptr_O_partial;
        StrideO const stride_O_partial;
        float* ptr_LSE;
        StrideLSE const stride_LSE;
        float* ptr_LSE_partial;
        StrideLSE const stride_LSE_partial;
        int32_t const nheads_kv;
        int const* cu_seqlens = nullptr;
        int const* seqused = nullptr;
    };

    // Device side kernel params
    struct Params {
        Element* ptr_O;
        ShapeO const shape_O;
        StrideO const stride_O;
        ShapeOPacked const shape_O_packed;
        StrideOPacked const stride_O_packed;
        ElementPartial* ptr_O_partial;
        StrideO const stride_O_partial;
        StrideOPacked const stride_O_partial_packed;
        float* ptr_LSE;
        StrideLSE const stride_LSE;
        ShapeLSEPacked const shape_LSE_packed;
        StrideLSEPacked const stride_LSE_packed;
        float* ptr_LSE_partial;
        StrideLSE const stride_LSE_partial;
        StrideLSEPacked const stride_LSE_partial_packed;
        cutlass::FastDivmod qhead_per_khead_divmod;
        TMA_O tma_store_O;
        int const* cu_seqlens = nullptr;
        int const* seqused = nullptr;
    };

    static Params
    to_underlying_arguments(Arguments const& args) {
        Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.shape_O, args.stride_O);
        TMA_O tma_store_O = [&]{
            if constexpr (Use_TMA_O) {
                return make_tma_copy(GmemTiledCopyOTMA{}, mO, SmemLayoutO{}, select<0, 1>(TileShape_MNK_PV{}), _1{}); // no mcast
            } else {
                return nullptr;
            }
        }();
        // If PackGQA, reshape O to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size, num_splits)
        int const qhead_per_khead = !PackGQA ? 1 : cute::ceil_div(get<2>(args.shape_O), args.nheads_kv);
        auto const shape_O_packed = cute::conditional_return<!PackGQA>(
            args.shape_O,
            make_shape(make_shape(qhead_per_khead, get<0>(args.shape_O)), get<1>(args.shape_O), args.nheads_kv, get<3>(args.shape_O), get<4>(args.shape_O))
        );
        auto const stride_O_packed = cute::conditional_return<!PackGQA>(
            args.stride_O,
            make_stride(make_stride(get<2>(args.stride_O), get<0>(args.stride_O)), get<1>(args.stride_O), get<2>(args.stride_O) * qhead_per_khead, get<3>(args.stride_O), get<4>(args.stride_O))
        );
        auto const stride_O_partial_packed = cute::conditional_return<!PackGQA>(
            args.stride_O_partial,
            make_stride(make_stride(get<2>(args.stride_O_partial), get<0>(args.stride_O_partial)), get<1>(args.stride_O_partial), get<2>(args.stride_O_partial) * qhead_per_khead, get<3>(args.stride_O_partial), get<4>(args.stride_O_partial))
        );
        // If PackGQA, reshape LSE to be ((qhead_per_khead, seqlen_q), nhead_k, batch_size, num_splits)
        auto const shape_LSE_packed = cute::conditional_return<!PackGQA>(
            select<0, 2, 3, 4>(args.shape_O),
            make_shape(make_shape(qhead_per_khead, get<0>(args.shape_O)), args.nheads_kv, get<3>(args.shape_O), get<4>(args.shape_O))
        );
        auto const stride_LSE_packed = cute::conditional_return<!PackGQA>(
            args.stride_LSE,
            make_stride(make_stride(get<1>(args.stride_LSE), get<0>(args.stride_LSE)), get<1>(args.stride_LSE) * qhead_per_khead, get<2>(args.stride_LSE), get<3>(args.stride_LSE))
        );
        auto const stride_LSE_partial_packed = cute::conditional_return<!PackGQA>(
            args.stride_LSE_partial,
            make_stride(make_stride(get<1>(args.stride_LSE_partial), get<0>(args.stride_LSE_partial)), get<1>(args.stride_LSE_partial) * qhead_per_khead, get<2>(args.stride_LSE_partial), get<3>(args.stride_LSE_partial))
        );
        return {args.ptr_O, args.shape_O, args.stride_O, shape_O_packed, stride_O_packed,
                args.ptr_O_partial, args.stride_O_partial, stride_O_partial_packed,
                args.ptr_LSE, args.stride_LSE, shape_LSE_packed, stride_LSE_packed,
                args.ptr_LSE_partial, args.stride_LSE_partial, stride_LSE_partial_packed,
                cutlass::FastDivmod(qhead_per_khead),
                tma_store_O, args.cu_seqlens, args.seqused};
    }

    /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
    CUTLASS_DEVICE
    static void prefetch_tma_descriptors(Params const& params) {
        if constexpr (Use_TMA_O) {
            cute::prefetch_tma_descriptor(params.tma_store_O.get_tma_descriptor());
        }
    }

    template <typename SharedStorage, typename FrgTensorO, typename FrgTensorLSE, typename TiledMma>
    CUTLASS_DEVICE void
    store(Params const& params,
          FrgTensorO& tOrO,
          FrgTensorLSE const& lse,
          SharedStorage& shared_storage,
          TiledMma tiled_mma,
          int thread_idx,
          cute::tuple<int32_t, int32_t, int32_t, int32_t> const& block_coord
          ) {

        auto [m_block, bidh, bidb, split_idx] = block_coord;
        int num_splits = get<4>(params.shape_O_packed);
        if constexpr (Split && Varlen) {
            uint32_t num_splits_dynamic_u = reinterpret_cast<uint32_t const&>(split_idx) >> 16; // first 16 bits are for num_splits
            int num_splits_dynamic = reinterpret_cast<int&>(num_splits_dynamic_u);
            num_splits = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits;
            split_idx &= 0x0000FFFF;  // Only use the lower 16 bits of split_idx
        }
        bool const is_split = !Split ? false : (!Varlen ? true : num_splits > 1);

        Tensor sO = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_o.data()), SmemLayoutO{});
        // Tensor sO_pi = cute::as_position_independent_swizzle_tensor(sO);

        static constexpr bool NeedFP8Permute = FP8PermuteCol && (sizeof(Element) == 2 || sizeof(Element) == 4);
        // If we will possibly need tOrO in FP32, we'd want to permute tOrO before type conversion.
        // Otherwise we can permute after conversion.
        if constexpr (NeedFP8Permute && Split) { flash::permute_output_fp8_Vcolmajor(tOrO); }
        Tensor tOrO_out = make_tensor_like<Element>(tOrO);
        flash::convert_type_out(tOrO, tOrO_out);
        if constexpr (NeedFP8Permute && !Split) { flash::permute_output_fp8_Vcolmajor(tOrO_out); }

        // Make sure all WGs have finished reading V
        // Technically we don't need this if we're not using smem, but the mainloop makes the assumption that
        // all epilogue threads sync at least once during the epilogue (so that we can start loading Q with
        // cp.async if we need).
        flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);

        // Step 1: Write O from rmem -> smem
        if constexpr (Use_smem) {
            auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
            auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
            Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out);        // ((Atom,AtomNum), MMA_M, MMA_N)
            Tensor taccOsO = smem_thr_copy_O.partition_D(sO);     // ((Atom,AtomNum),PIPE_M,PIPE_N)
            // Tensor taccOsO = smem_thr_copy_O.partition_D(sO_pi);     // ((Atom,AtomNum),PIPE_M,PIPE_N)
            cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
            if constexpr (Use_TMA_O) {
                cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
                cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
                                                    cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
            } else {
                flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
            }
        } else {
            if constexpr (ArchTag::kMinComputeCapability >= 90) {
                #pragma unroll
                for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
                    shared_storage.pipelines.barrier_O.arrive(cta_id);
                }
            }
        }

        flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused};
        bool is_varlen = Varlen && params.cu_seqlens;
        int offset_o = seqlen_info.offset;
        int seqlen_o = seqlen_info.seqlen;
        int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0);

        // Step 2: Write LSE from rmem -> gmem
        auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
        // (MMA,MMA_M,MMA_K)
        Tensor taccOcO = thread_mma.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{})));
        static_assert(decltype(size<0, 0>(taccOcO))::value == 2);
        static_assert(decltype(size<0, 1>(taccOcO))::value == 2);
        Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout()));
        Tensor taccOcO_row = taccOcO_rowcol(_, _0{});
        CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row));                     // MMA_M

        using PackGQA_t = flash::PackGQAManager<get<0>(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>;
        using PackGQApartial_t = flash::PackGQAManager<get<0>(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, ElementPartial>;

        Tensor mLSE = make_tensor(make_gmem_ptr((!is_split ? params.ptr_LSE : params.ptr_LSE_partial) + offset_o * get<0>(!is_split ? params.stride_LSE : params.stride_LSE_partial)),
                                  params.shape_LSE_packed,
                                  !is_split ? params.stride_LSE_packed : params.stride_LSE_partial_packed)(_, bidh, !is_varlen ? bidb : 0, !is_split ? 0 : split_idx);
        // if (thread_idx == 0) { printf("Before LSE write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o); print(mLSE); printf("\n"); }
        if (!LargeHeadDimV || warp_group_idx == 0) {
            if constexpr (!PackGQA) {
                #pragma unroll
                for (int mi = 0; mi < size(lse); ++mi) {
                    int const row = m_block * kBlockM + get<0>(taccOcO_row(mi));
                    if (get<1>(taccOcO_row(_0{})) == 0 && row < seqlen_o) { mLSE(row) = lse(mi); }
                }
            } else {
                PackGQA_t::store_LSE(mLSE, lse, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
            }
        }

        // Step 3: Write O from smem -> gmem
        if constexpr (Use_TMA_O) {
            Tensor mO = params.tma_store_O.get_tma_tensor(params.shape_O)(_, _, bidh, bidb, split_idx);
            Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{}));  // (M, K)
            auto block_tma_O = params.tma_store_O.get_slice(_0{});
            Tensor tOgO = block_tma_O.partition_D(gO);  // (TMA, TMA_M, TMA_K)
            Tensor tOsO = block_tma_O.partition_S(sO); // (TMA, TMA_M, TMA_K)
            int warp_idx_sync = __shfl_sync(0xffffffff, thread_idx / cutlass::NumThreadsPerWarp, 0);
            if (warp_idx_sync == NumEpilogueThreads / cutlass::NumThreadsPerWarp - 1) {
                cutlass::arch::NamedBarrier::sync(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
                                                  cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
                if (cute::elect_one_sync()) {
                    cute::copy(params.tma_store_O, tOsO, tOgO);
                    tma_store_arrive();
                    tma_store_wait<0>();
                    #pragma unroll
                    for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
                        shared_storage.pipelines.barrier_O.arrive(cta_id);
                    }
                }
            }
        } else {  // Don't use TMA in Varlen case since we don't want to overwrite the output of another sequence
            if (!is_split) {
                Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, _0{});
                Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{}));  // (M, K)
                // if (thread_idx == 0) { printf("Before O write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d, mO_addr = %p, addr diff = %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o, mO.data(), reinterpret_cast<int>(&mO(0)) - reinterpret_cast<int>(params.ptr_O)); }
                GmemTiledCopyO gmem_tiled_copy_O;
                auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
                Tensor tOsO = gmem_thr_copy_O.partition_S(sO);        // ((Atom,AtomNum),ATOM_M,ATOM_N)
                // Tensor tOsO = gmem_thr_copy_O.partition_S(sO_pi);        // ((Atom,AtomNum),ATOM_M,ATOM_N)
                Tensor tOrO = make_fragment_like(tOsO);
                cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
                if constexpr (ArchTag::kMinComputeCapability >= 90) {
                    cutlass::arch::fence_view_async_shared(); // ensure smem reads are done before next TMA to smem_v
                    #pragma unroll
                    for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
                        shared_storage.pipelines.barrier_O.arrive(cta_id);
                    }
                }
                if constexpr (!PackGQA) {
                    // (BLK_M,BLK_K) -> (blk_m,blk_k)
                    Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{})));
                    Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOsO)));
                    #pragma unroll
                    for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); }
                    Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
                    // Clear_OOB_K must be false since we don't want to write zeros to gmem
                    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
                        gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM
                    );
                } else {
                    // If PackGQA, we split the work of computing O_ptr among threads in the same row
                    PackGQA_t::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
                }
            } else {
                Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset_o * get<0>(params.stride_O_partial)), params.shape_O_packed, params.stride_O_partial_packed)(_, _, bidh, !is_varlen ? bidb : 0, split_idx);
                Tensor gOpartial = local_tile(mOpartial, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{}));  // (M, K)
                // We already arrived on barrier_O earlier if !Use_smem
                if constexpr (Use_smem) {
                    if constexpr (ArchTag::kMinComputeCapability >= 90) {
                        #pragma unroll
                        for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
                            shared_storage.pipelines.barrier_O.arrive(cta_id);
                        }
                    }
                }
                if constexpr (!PackGQA) {
                    static constexpr int kGmemElemsPerStoreDirect = 2;
                    cute::Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementPartial> gmem_copy_direct;
                    // Reshape acc from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
                    Tensor tOrO_rowcol = make_tensor(tOrO.data(), flash::convert_layout_acc_rowcol(tOrO.layout()));
                    Tensor tOrO_copy = cute::tiled_divide(tOrO_rowcol, Shape<_1, Int<kGmemElemsPerStoreDirect>>{});
                    Tensor tOgO = thread_mma.partition_C(gOpartial);
                    Tensor tOgO_rowcol = make_tensor(tOgO.data(), flash::convert_layout_acc_rowcol(tOgO.layout()));
                    Tensor tOgO_copy = cute::tiled_divide(tOgO_rowcol, Shape<_1, Int<kGmemElemsPerStoreDirect>>{});
                    Tensor taccOcO_col = taccOcO_rowcol(_0{}, _);
                    #pragma unroll
                    for (int m = 0; m < size(taccOcO_row); ++m) {
                        if (get<0>(taccOcO_row(m)) < seqlen_o - m_block * kBlockM) {
                            #pragma unroll
                            for (int k = 0; k < size(taccOcO_col) / kGmemElemsPerStoreDirect; ++k) {
                                if (get<1>(taccOcO_col(k * kGmemElemsPerStoreDirect)) < get<1>(params.shape_O)) {
                                    cute::copy(gmem_copy_direct, tOrO_copy(_, m, k), tOgO_copy(_, m, k));
                                }
                            }
                        }
                    }
                } else {
                    PackGQApartial_t::store_O_direct(mOpartial, tOrO, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
                }
            }
        }
    }

    CUTLASS_DEVICE void
    store_tail() {
        // Don't need to do tma_store_wait<0>() here since we already did in @store
    }

    // Write 0 to output and -inf to LSE
    CUTLASS_DEVICE void
    store_zero(
         Params const& params,
         int thread_idx,
         cute::tuple<int32_t, int32_t, int32_t, int32_t> const& block_coord
         ) {
        static constexpr int kBlockM = get<0>(TileShape_MNK_PV{});
        auto [m_block, bidh, bidb, split_idx] = block_coord;
        int num_splits = get<4>(params.shape_O_packed);
        if constexpr (Split && Varlen) {
            uint32_t num_splits_dynamic_u = reinterpret_cast<uint32_t const&>(split_idx) >> 16; // first 16 bits are for num_splits
            int num_splits_dynamic = reinterpret_cast<int&>(num_splits_dynamic_u);
            num_splits = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits;
            split_idx &= 0x0000FFFF;  // Only use the lower 16 bits of split_idx
        }
        bool const is_split = !Split ? false : (!Varlen ? true : num_splits > 1);

        flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused};
        bool const is_varlen = Varlen && params.cu_seqlens;
        int offset_o = seqlen_info.offset;
        int seqlen_o = seqlen_info.seqlen;
        int qhead_per_khead = !PackGQA ? 1 : params.qhead_per_khead_divmod.divisor;
        Tensor mLSE = make_tensor(make_gmem_ptr((!is_split ? params.ptr_LSE : params.ptr_LSE_partial) + offset_o * get<0>(!is_split ? params.stride_LSE : params.stride_LSE_partial)),
                                  params.shape_LSE_packed,
                                  !is_split ? params.stride_LSE_packed : params.stride_LSE_partial_packed)(_, bidh, !is_varlen ? bidb : 0, !is_split ? 0 : split_idx);
        Tensor gLSE = local_tile(mLSE, Shape<Int<kBlockM>>{}, make_coord(m_block));

        static_assert(kBlockM <= NumEpilogueThreads);
        if (thread_idx < kBlockM) {
            const int row = m_block * kBlockM + thread_idx;
            if constexpr (!PackGQA) {
                if (row < seqlen_o) { mLSE(row) = -INFINITY; }
            } else {
                if (row < seqlen_o * qhead_per_khead) {
                    int m_idx, h_idx;
                    m_idx = params.qhead_per_khead_divmod.divmod(h_idx, row);
                    // mLSE has shape ((qhead_per_khead, seqlen_q)) and it's unhappy with just 1 "make_coord"
                    mLSE(make_coord(make_coord(h_idx, m_idx))) = -INFINITY;
                }
            }
        }

        // If split, we don't have to write 0 to mOpartial if the mha_combine kernel is used,
        // since it will not use the value of O if LSE is -inf.
        if (!is_split) {
            Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, _0{});

            GmemTiledCopyO gmem_tiled_copy_O;
            auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
            Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{})));
            if constexpr (!PackGQA) {
                Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOcO)));
                #pragma unroll
                for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); }
                Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{}));  // (M, K)
                Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
                Tensor tOrO = make_fragment_like(tOgO);
                cute::clear(tOrO);
                // Clear_OOB_K must be false since we don't want to write zeros to gmem
                flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
                    gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM
                );
            } else {
                // If PackGQA, we split the work of computing O_ptr among threads in the same row
                using PackGQA_t = flash::PackGQAManager<get<0>(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>;
                Tensor tOrO = make_tensor<Element>(make_shape(Shape<_1, Int<kGmemElemsPerStore>>{}, size<1>(tOcO), size<2>(tOcO)));
                cute::clear(tOrO);
                PackGQA_t::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
            }
        }

    }

};

} // namespace flash
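For reference on the `Split && Varlen` path above: `store()` and `store_zero()` recover a dynamic split count from the upper 16 bits of `split_idx` and keep only the lower 16 bits as the actual split index, falling back to the static `num_splits` when the upper bits are zero. Below is a minimal host-side sketch of that bit-packing convention; the `SplitInfo` and `decode_split` names are illustrative, not part of the kernel source:

```cpp
#include <cstdint>
#include <cstdio>

// Illustrative decode of the packed split coordinate used in store()/store_zero():
// upper 16 bits carry a dynamic num_splits (0 means "use the static value"),
// lower 16 bits carry the actual split index.
struct SplitInfo { int num_splits; int split_idx; };

SplitInfo decode_split(int packed, int num_splits_static) {
    uint32_t u = static_cast<uint32_t>(packed);
    int num_splits_dynamic = static_cast<int>(u >> 16);
    return SplitInfo{
        num_splits_dynamic > 0 ? num_splits_dynamic : num_splits_static,
        static_cast<int>(u & 0xFFFFu),
    };
}

int main() {
    // A scheduler that chose 3 splits and is emitting split #2 might pack:
    int packed = (3 << 16) | 2;
    SplitInfo s = decode_split(packed, /*num_splits_static=*/8);
    std::printf("num_splits=%d split_idx=%d\n", s.num_splits, s.split_idx);  // 3, 2
}
```

Packing both values into a single 32-bit coordinate lets the variable-length scheduler pass a per-batch split count through the existing `block_coord` tuple without widening it.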
flash-attn/flash.h
DELETED

@@ -1,220 +0,0 @@
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cuda.h>
#include <vector>

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Qkv_params {
    using index_t = int64_t;
    // The QKV matrices.
    void *__restrict__ q_ptr;
    void *__restrict__ k_ptr;
    void *__restrict__ v_ptr;

    // The stride between rows of the Q, K and V matrices.
    index_t q_batch_stride;
    index_t k_batch_stride;
    index_t v_batch_stride;
    index_t q_row_stride;
    index_t k_row_stride;
    index_t v_row_stride;
    index_t q_head_stride;
    index_t k_head_stride;
    index_t v_head_stride;
    index_t v_dim_stride;

    // The number of heads.
    int h, h_k;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Flash_fwd_params : public Qkv_params {
    using index_t = int64_t;

    // The O matrix (output).
    void * __restrict__ o_ptr;
    void * __restrict__ oaccum_ptr;

    // The stride between rows of O.
    index_t o_batch_stride;
    index_t o_row_stride;
    index_t o_head_stride;

    // The pointer to the softmax sum.
    void * __restrict__ softmax_lse_ptr;
    void * __restrict__ softmax_lseaccum_ptr;

    // For FP8 scaling
    float * __restrict__ q_descale_ptr;
    float * __restrict__ k_descale_ptr;
    float * __restrict__ v_descale_ptr;
    index_t q_descale_batch_stride;
    index_t q_descale_head_stride;
    index_t k_descale_batch_stride;
    index_t k_descale_head_stride;
    index_t v_descale_batch_stride;
    index_t v_descale_head_stride;

    // The dimensions.
    int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
    int total_q, total_k, total_knew;
    int b_k;  // When having KV cache and with cache_batch_idx, K & V might have larger batch size than Q
    int dv, dv_rounded;  // For the case where V headdim is different from Q/K headdim

    // The scaling factors for the kernel.
    float scale_softmax;
    float softcap;

    // array of length b+1 holding starting offset of each sequence.
    int * __restrict__ cu_seqlens_q;
    int * __restrict__ cu_seqlens_k;
    int * __restrict__ cu_seqlens_knew;
    int * __restrict__ leftpad_k;

    // If provided, the actual length of each q/k sequence.
    int *__restrict__ seqused_q;
    int *__restrict__ seqused_k;

    // The stride between rows of Oaccum.
    index_t oaccum_split_stride;
    index_t oaccum_batch_stride;
    index_t oaccum_row_stride;
    index_t oaccum_head_stride;

    // The stride between rows of LSEaccum.
    index_t lseaccum_split_stride;
    index_t lseaccum_batch_stride;
    index_t lseaccum_head_stride;

    // The K_new and V_new matrices.
    void * __restrict__ knew_ptr;
    void * __restrict__ vnew_ptr;

    // The stride between rows of the Q, K and V matrices.
    index_t knew_batch_stride;
    index_t vnew_batch_stride;
    index_t knew_row_stride;
    index_t vnew_row_stride;
    index_t knew_head_stride;
    index_t vnew_head_stride;

    void *__restrict__ qv_ptr;
    index_t qv_batch_stride;
    index_t qv_row_stride;
    index_t qv_head_stride;

    // The cos and sin matrices for rotary embedding.
    void * __restrict__ rotary_cos_ptr;
    void * __restrict__ rotary_sin_ptr;
    int *__restrict__ seqlens_rotary;

    // The indices to index into the KV cache.
    int * __restrict__ kv_batch_idx;

    // Paged KV cache
            -
                int * __restrict__ page_table;
         | 
| 122 | 
            -
                index_t page_table_batch_stride;
         | 
| 123 | 
            -
                int page_size;
         | 
| 124 | 
            -
                int num_pages;
         | 
| 125 | 
            -
                bool pagedkv_tma;
         | 
| 126 | 
            -
             | 
| 127 | 
            -
                // The dropout probability (probability of keeping an activation).
         | 
| 128 | 
            -
                float p_dropout;
         | 
| 129 | 
            -
                // uint32_t p_dropout_in_uint;
         | 
| 130 | 
            -
                // uint16_t p_dropout_in_uint16_t;
         | 
| 131 | 
            -
                uint8_t p_dropout_in_uint8_t;
         | 
| 132 | 
            -
             | 
| 133 | 
            -
                // Scale factor of 1 / (1 - p_dropout).
         | 
| 134 | 
            -
                float rp_dropout;
         | 
| 135 | 
            -
             | 
| 136 | 
            -
                // Local window size
         | 
| 137 | 
            -
                int window_size_left, window_size_right;
         | 
| 138 | 
            -
             | 
| 139 | 
            -
                // Pointer to the RNG seed (idx 0) and offset (idx 1).
         | 
| 140 | 
            -
                uint64_t * rng_state;
         | 
| 141 | 
            -
             | 
| 142 | 
            -
                bool is_bf16;
         | 
| 143 | 
            -
                bool is_fp32;
         | 
| 144 | 
            -
                bool is_e4m3;
         | 
| 145 | 
            -
                bool is_causal;
         | 
| 146 | 
            -
                bool is_local;
         | 
| 147 | 
            -
             | 
| 148 | 
            -
                bool is_rotary_interleaved;
         | 
| 149 | 
            -
             | 
| 150 | 
            -
                int num_splits;  // For split-KV version
         | 
| 151 | 
            -
                bool pack_gqa;
         | 
| 152 | 
            -
             | 
| 153 | 
            -
                int * __restrict__ tile_count_semaphore;
         | 
| 154 | 
            -
                // int * __restrict__ num_m_blocks_ptr;
         | 
| 155 | 
            -
                // int * __restrict__ num_n_blocks_ptr;
         | 
| 156 | 
            -
                int * __restrict__ num_splits_dynamic_ptr;
         | 
| 157 | 
            -
                bool skip_scheduler_metadata_computation;
         | 
| 158 | 
            -
             | 
| 159 | 
            -
                int arch;
         | 
| 160 | 
            -
                int num_sm;
         | 
| 161 | 
            -
             | 
| 162 | 
            -
                // The S extra matrix, (num_heads)
         | 
| 163 | 
            -
                void *__restrict__ s_aux_ptr;
         | 
| 164 | 
            -
            };
         | 
| 165 | 
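
Note on the varlen fields above: cu_seqlens_q / cu_seqlens_k follow the usual packed-sequence convention, holding b+1 prefix sums of the per-sequence lengths, so sequence i occupies packed rows [cu[i], cu[i+1]) and the last entry equals total_q / total_k. A minimal standalone sketch of that convention (the helper name make_cu_seqlens is illustrative and not part of this header):

    #include <cassert>
    #include <vector>

    // Hypothetical helper: build a cu_seqlens array (length b+1) from per-sequence lengths.
    std::vector<int> make_cu_seqlens(const std::vector<int>& seqlens) {
        std::vector<int> cu(seqlens.size() + 1, 0);
        for (size_t i = 0; i < seqlens.size(); ++i) {
            cu[i + 1] = cu[i] + seqlens[i];  // running prefix sum of lengths
        }
        return cu;
    }

    int main() {
        auto cu = make_cu_seqlens({3, 5, 2});
        // Sequence i occupies packed rows [cu[i], cu[i+1]); total tokens == cu.back().
        assert(cu == (std::vector<int>{0, 3, 8, 10}));
        return 0;
    }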
            -
             | 
| 166 | 
            -
            ////////////////////////////////////////////////////////////////////////////////////////////////////
         | 
| 167 | 
            -
             | 
| 168 | 
            -
            struct Flash_bwd_params : public Flash_fwd_params {
         | 
| 169 | 
            -
                using index_t = int64_t;
         | 
| 170 | 
            -
             | 
| 171 | 
            -
                // The dO and dQKV matrices.
         | 
| 172 | 
            -
                void *__restrict__ do_ptr;
         | 
| 173 | 
            -
                void *__restrict__ dq_ptr;
         | 
| 174 | 
            -
                void *__restrict__ dk_ptr;
         | 
| 175 | 
            -
                void *__restrict__ dv_ptr;
         | 
| 176 | 
            -
             | 
| 177 | 
            -
                // To accumulate dQ
         | 
| 178 | 
            -
                void *__restrict__ dq_accum_ptr;
         | 
| 179 | 
            -
                void *__restrict__ dk_accum_ptr;
         | 
| 180 | 
            -
                void *__restrict__ dv_accum_ptr;
         | 
| 181 | 
            -
             | 
| 182 | 
            -
                // // To accumulate dK and dV in case we're splitting the bwd along seqlen_q
         | 
| 183 | 
            -
                // dimension void *__restrict__ dk_accum_ptr; void *__restrict__
         | 
| 184 | 
            -
                // dv_accum_ptr;
         | 
| 185 | 
            -
             | 
| 186 | 
            -
                // The stride between rows of the dO, dQ, dK and dV matrices.
         | 
| 187 | 
            -
                index_t do_batch_stride;
         | 
| 188 | 
            -
                index_t do_row_stride;
         | 
| 189 | 
            -
                index_t do_head_stride;
         | 
| 190 | 
            -
                index_t dq_batch_stride;
         | 
| 191 | 
            -
                index_t dk_batch_stride;
         | 
| 192 | 
            -
                index_t dv_batch_stride;
         | 
| 193 | 
            -
                index_t dq_row_stride;
         | 
| 194 | 
            -
                index_t dk_row_stride;
         | 
| 195 | 
            -
                index_t dv_row_stride;
         | 
| 196 | 
            -
                index_t dq_head_stride;
         | 
| 197 | 
            -
                index_t dk_head_stride;
         | 
| 198 | 
            -
                index_t dv_head_stride;
         | 
| 199 | 
            -
             | 
| 200 | 
            -
                // The pointer to the softmax d sum.
         | 
| 201 | 
            -
                void *__restrict__ dsoftmax_sum;
         | 
| 202 | 
            -
                void *__restrict__ softmax_lse_log2_ptr;
         | 
| 203 | 
            -
             | 
| 204 | 
            -
                int *__restrict__ dq_semaphore;
         | 
| 205 | 
            -
                int *__restrict__ dk_semaphore;
         | 
| 206 | 
            -
                int *__restrict__ dv_semaphore;
         | 
| 207 | 
            -
             | 
| 208 | 
            -
                bool deterministic;
         | 
| 209 | 
            -
                index_t dq_accum_split_stride;
         | 
| 210 | 
            -
            };
         | 
| 211 | 
            -
             | 
| 212 | 
            -
            ////////////////////////////////////////////////////////////////////////////////////////////////////
         | 
| 213 | 
            -
             | 
| 214 | 
            -
            template <int Arch, typename T, int kHeadDim, int kHeadDimV, bool Split, bool PagedKVNonTMA, bool Has_softcap, bool PackGQA>
         | 
| 215 | 
            -
            void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream);
         | 
| 216 | 
            -
            void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bool packgqa, int blockM, int blockN, bool enable_pdl);
         | 
| 217 | 
            -
            template <int Arch, typename T, int kHeadDim, bool Has_softcap>
         | 
| 218 | 
            -
            void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream);
         | 
| 219 | 
            -
            template <typename T, typename Tpartial, int kBlockK>
         | 
| 220 | 
            -
            void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl);
         | 
flash-attn/flash_api.cpp DELETED

@@ -1,1623 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
- ******************************************************************************/
-
-// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
-#include <torch/nn/functional.h>
-#include <torch/version.h>  // For TORCH_VERSION* macros
-#include <ATen/cuda/CUDAContext.h>
-#include <c10/cuda/CUDAGuard.h>
-
-#include <cutlass/numeric_types.h>
-
-#include "flash.h"
-#include "static_switch.h"
-#include "tile_size.h"
-#include "heuristics.h"
-#include "cuda_check.h"
-
-// Copied from https://github.com/pytorch/pytorch/commit/7931eee5c5ebcdf468bff4d308510b03355cd909
-// This is so that we can pass in torch.dtype as a parameter to the function.
-#if TORCH_VERSION_MAJOR < 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR < 4)
-
-#include <pybind11/pybind11.h>
-#include <pybind11/stl.h>
-
-namespace pybind11::detail {
-
-    template <>
-    struct type_caster<at::ScalarType> {
-    public:
-        // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
-        PYBIND11_TYPE_CASTER(at::ScalarType, _("torch.dtype"));
-        // PYBIND11_TYPE_CASTER defines a member field called value. at::ScalarType
-        // cannot be default-initialized, we provide this constructor to explicitly
-        // initialize that field. The value doesn't matter as it will be overwritten
-        // after a successful call to load.
-        type_caster() : value(at::kFloat) {}
-        bool load(handle src, bool) {
-            PyObject* obj = src.ptr();
-            if (THPDtype_Check(obj)) {
-                value = reinterpret_cast<THPDtype*>(obj)->scalar_type;
-                return true;
-            }
-            return false;
-        }
-        static handle cast(
-                           const at::ScalarType& src,
-                           return_value_policy /* policy */,
-                           handle /* parent */) {
-            return Py_NewRef(torch::getTHPDtype(src));
-        }
-    };
-
-} // namespace pybind11::detail
-
-#endif
-
-#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
-#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
-#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
-
-void set_params_fprop(Flash_fwd_params &params,
-                      // sizes
-                      const size_t b,
-                      const size_t seqlen_q,
-                      const size_t seqlen_k,
-                      const size_t seqlen_q_rounded,
-                      const size_t seqlen_k_rounded,
-                      const size_t h,
-                      const size_t h_k,
-                      const size_t d,
-                      const size_t d_rounded,
-                      // device pointers
-                      const at::Tensor q,
-                      const at::Tensor k,
-                      const at::Tensor v,
-                      at::Tensor out,
-                      void *cu_seqlens_q_d,
-                      void *cu_seqlens_k_d,
-                      void *seqused_q,
-                      void *seqused_k,
-                      void *softmax_lse_d,
-                      float p_dropout,
-                      float softmax_scale,
-                      int window_size_left,
-                      int window_size_right,
-                      const float softcap=0.f,
-                      const int sm_margin=0) {
-
-    // Reset the parameters
-    params = {};
-
-    params.is_bf16 = q.dtype() == torch::kBFloat16;
-    params.is_e4m3 = q.dtype() == torch::kFloat8_e4m3fn;
-
-    // Set the pointers and strides.
-    params.q_ptr = q.data_ptr();
-    params.k_ptr = k.data_ptr();
-    params.v_ptr = v.data_ptr();
-    // All strides are in elements, not bytes.
-    params.q_row_stride = q.stride(-3);
-    params.k_row_stride = k.stride(-3);
-    params.v_row_stride = v.stride(-3);
-    params.q_head_stride = q.stride(-2);
-    params.k_head_stride = k.stride(-2);
-    params.v_head_stride = v.stride(-2);
-    params.v_dim_stride = v.stride(-1);
-    params.o_ptr = out.data_ptr();
-    params.o_row_stride = out.stride(-3);
-    params.o_head_stride = out.stride(-2);
-
-    if (cu_seqlens_q_d == nullptr) {
-        params.q_batch_stride = q.stride(0);
-        params.o_batch_stride = out.stride(0);
-    }
-    if (cu_seqlens_k_d == nullptr) {
-        params.k_batch_stride = k.stride(0);
-        params.v_batch_stride = v.stride(0);
-    }
-
-    params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
-    params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
-    params.seqused_q = static_cast<int *>(seqused_q);
-    params.seqused_k = static_cast<int *>(seqused_k);
-
-    // Softmax sum
-    params.softmax_lse_ptr = softmax_lse_d;
-
-    // Set the dimensions.
-    params.b = b;
-    params.h = h;
-    params.h_k = h_k;
-    params.seqlen_q = seqlen_q;
-    params.seqlen_k = seqlen_k;
-    params.seqlen_q_rounded = seqlen_q_rounded;
-    params.seqlen_k_rounded = seqlen_k_rounded;
-    params.d = d;
-    params.d_rounded = d_rounded;
-
-    // Set the different scale values.
-    params.scale_softmax = softmax_scale;
-    params.softcap = softcap;
-
-    // Set this to probability of keeping an element to simplify things.
-    params.p_dropout = 1.f - p_dropout;
-    // Convert p from float to int so we don't have to convert the random uint to float to compare.
-    // [Minor] We want to round down since when we do the comparison we use <= instead of <
-    // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
-    // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
-    params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
-    params.rp_dropout = 1.f / params.p_dropout;
-    TORCH_CHECK(p_dropout < 1.f);
-    #ifdef FLASHATTENTION_DISABLE_DROPOUT
-        TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
-    #endif
-
-    // Causal is the special case where window_size_right == 0 and window_size_left < 0.
-    // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
-    params.is_causal = window_size_left < 0 && window_size_right == 0;
-    params.is_local = (window_size_left >= 0 || window_size_right >= 0) && !params.is_causal;
-
-    // TODO: check this
-    if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k - 1; }
-    if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_q - 1; }
-    params.window_size_left = window_size_left;
-    params.window_size_right = window_size_right;
-
-    params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor;
-    params.num_sm = at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin;
-
-    #ifdef FLASHATTENTION_DISABLE_LOCAL
-        TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
-    #endif
-}
         | 
| 175 | 
            -
             | 
| 176 | 
            -
            void set_params_dgrad(Flash_bwd_params ¶ms,
         | 
| 177 | 
            -
                                  // sizes
         | 
| 178 | 
            -
                                  const size_t b,
         | 
| 179 | 
            -
                                  const size_t seqlen_q,
         | 
| 180 | 
            -
                                  const size_t seqlen_k,
         | 
| 181 | 
            -
                                  const size_t seqlen_q_rounded,
         | 
| 182 | 
            -
                                  const size_t seqlen_k_rounded,
         | 
| 183 | 
            -
                                  const size_t h,
         | 
| 184 | 
            -
                                  const size_t h_k,
         | 
| 185 | 
            -
                                  const size_t d,
         | 
| 186 | 
            -
                                  const size_t d_rounded,
         | 
| 187 | 
            -
                                  // device pointers
         | 
| 188 | 
            -
                                  const at::Tensor q,
         | 
| 189 | 
            -
                                  const at::Tensor k,
         | 
| 190 | 
            -
                                  const at::Tensor v,
         | 
| 191 | 
            -
                                  const at::Tensor out,
         | 
| 192 | 
            -
                                  const at::Tensor dout,
         | 
| 193 | 
            -
                                  at::Tensor dq,
         | 
| 194 | 
            -
                                  at::Tensor dk,
         | 
| 195 | 
            -
                                  at::Tensor dv,
         | 
| 196 | 
            -
                                  void *cu_seqlens_q_d,
         | 
| 197 | 
            -
                                  void *cu_seqlens_k_d,
         | 
| 198 | 
            -
                                  void *seqused_q,
         | 
| 199 | 
            -
                                  void *seqused_k,
         | 
| 200 | 
            -
                                  void *dq_accum_d,
         | 
| 201 | 
            -
                                  void *dk_accum_d,
         | 
| 202 | 
            -
                                  void *dv_accum_d,
         | 
| 203 | 
            -
                                  void *softmax_lse_d,
         | 
| 204 | 
            -
                                  void *dsoftmax_sum_d,
         | 
| 205 | 
            -
                                  float p_dropout,
         | 
| 206 | 
            -
                                  float softmax_scale,
         | 
| 207 | 
            -
                                  int window_size_left,
         | 
| 208 | 
            -
                                  int window_size_right,
         | 
| 209 | 
            -
                                  const float softcap=0.f,
         | 
| 210 | 
            -
                                  bool deterministic=false,
         | 
| 211 | 
            -
                                  int const sm_margin=0) {
         | 
| 212 | 
            -
             | 
| 213 | 
            -
                set_params_fprop(params,
         | 
| 214 | 
            -
                                 b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
         | 
| 215 | 
            -
                                 q, k, v, out,
         | 
| 216 | 
            -
                                 cu_seqlens_q_d,
         | 
| 217 | 
            -
                                 cu_seqlens_k_d,
         | 
| 218 | 
            -
                                 seqused_q,
         | 
| 219 | 
            -
                                 seqused_k,
         | 
| 220 | 
            -
                                 softmax_lse_d,
         | 
| 221 | 
            -
                                 p_dropout,
         | 
| 222 | 
            -
                                 softmax_scale,
         | 
| 223 | 
            -
                                 window_size_left,
         | 
| 224 | 
            -
                                 window_size_right,
         | 
| 225 | 
            -
                                 softcap,
         | 
| 226 | 
            -
                                 sm_margin);
         | 
| 227 | 
            -
             | 
| 228 | 
            -
                // Set the pointers and strides.
         | 
| 229 | 
            -
                params.do_ptr = dout.data_ptr();
         | 
| 230 | 
            -
                params.do_row_stride = dout.stride(-3);
         | 
| 231 | 
            -
                params.do_head_stride = dout.stride(-2);
         | 
| 232 | 
            -
                params.dq_ptr = dq.data_ptr();
         | 
| 233 | 
            -
                params.dk_ptr = dk.data_ptr();
         | 
| 234 | 
            -
                params.dv_ptr = dv.data_ptr();
         | 
| 235 | 
            -
                params.dq_row_stride = dq.stride(-3);
         | 
| 236 | 
            -
                params.dk_row_stride = dk.stride(-3);
         | 
| 237 | 
            -
                params.dv_row_stride = dv.stride(-3);
         | 
| 238 | 
            -
                params.dq_head_stride = dq.stride(-2);
         | 
| 239 | 
            -
                params.dk_head_stride = dk.stride(-2);
         | 
| 240 | 
            -
                params.dv_head_stride = dv.stride(-2);
         | 
| 241 | 
            -
             | 
| 242 | 
            -
                if (cu_seqlens_q_d == nullptr) {
         | 
| 243 | 
            -
                    params.do_batch_stride = dout.stride(0);
         | 
| 244 | 
            -
                    params.dq_batch_stride = dq.stride(0);
         | 
| 245 | 
            -
                    params.dk_batch_stride = dk.stride(0);
         | 
| 246 | 
            -
                    params.dv_batch_stride = dv.stride(0);
         | 
| 247 | 
            -
                }
         | 
| 248 | 
            -
             | 
| 249 | 
            -
                params.dq_accum_ptr = dq_accum_d;
         | 
| 250 | 
            -
                params.dk_accum_ptr = dk_accum_d;
         | 
| 251 | 
            -
                params.dv_accum_ptr = dv_accum_d;
         | 
| 252 | 
            -
             | 
| 253 | 
            -
                // Softmax sum
         | 
| 254 | 
            -
                params.dsoftmax_sum = dsoftmax_sum_d;
         | 
| 255 | 
            -
             | 
| 256 | 
            -
                params.deterministic = deterministic;
         | 
| 257 | 
            -
            }
         | 
-
-void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
-    // HEADDIM_SWITCH(params.d, [&] {
-    //     run_mha_fwd_<cutlass::half_t, kHeadSize>(params, stream);
-    // });
-    TORCH_CHECK(params.num_splits >= 1);
-    ARCH_SWITCH(params.arch, Arch, [&] {
-        SPLIT_SWITCH(params.num_splits > 1, Split, [&] {
-            PAGEDKV_SWITCH(params.page_table && !params.pagedkv_tma, PagedKVNonTMA, [&] {
-                PACKGQA_SWITCH(params.pack_gqa, PackGQA_, [&] {
-                    // Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation
-                    static constexpr bool PackGQA = PackGQA_ || Arch < 90 || PagedKVNonTMA || Split;
-                    SOFTCAP_SWITCH(params.softcap > 0.0, Has_softcap, [&] {
-                        if (!params.is_e4m3) {
-                            if (params.is_bf16) {
-                                #ifndef FLASHATTENTION_DISABLE_HDIM64
-                                if (params.d <= 64) {
-                                    if (params.dv > 256 && Arch == 90) {
-                                        return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, 512, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
-                                    } else if (params.dv > 64 && Arch == 90) {
-                                        return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
-                                    } else {
-                                        return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
-                                    }
-                                }
-                                #endif
-                                #ifndef FLASHATTENTION_DISABLE_HDIM96
-                                if (params.d <= 96) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 96, 96, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
-                                #endif
-                                #ifndef FLASHATTENTION_DISABLE_HDIM128
-                                if (params.d <= 128) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 128, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
-                                #endif
-                                #ifndef FLASHATTENTION_DISABLE_HDIM192
-                                if (params.d <= 192) {
-                                    if (params.dv <= 128 && Arch == 90) {
-                                        return run_mha_fwd_<Arch, cutlass::bfloat16_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
-                                    } else {
-                                        return run_mha_fwd_<Arch, cutlass::bfloat16_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
-                                    }
-                                }
-                                #endif
-                                #ifndef FLASHATTENTION_DISABLE_HDIM256
-                                if (params.d <= 256) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 256, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
-                                #endif
-                            } else {
-                                #ifndef FLASHATTENTION_DISABLE_FP16
-                                #ifndef FLASHATTENTION_DISABLE_HDIM64
-                                if (params.d <= 64) {
-                                    if (params.dv > 256 && Arch == 90) {
-                                        return run_mha_fwd_<Arch, cutlass::half_t, 64, 512, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
-                                    } else if (params.dv > 64 && Arch == 90) {
-                                        return run_mha_fwd_<Arch, cutlass::half_t, 64, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
-                                    } else {
-                                        return run_mha_fwd_<Arch, cutlass::half_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
-                                    }
-                                }
-                                #endif
-                                #ifndef FLASHATTENTION_DISABLE_HDIM96
-                                if (params.d <= 96) { return run_mha_fwd_<Arch, cutlass::half_t, 96, 96, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
-                                #endif
-                                #ifndef FLASHATTENTION_DISABLE_HDIM128
-                                if (params.d <= 128) { return run_mha_fwd_<Arch, cutlass::half_t, 128, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
-                                #endif
-                                #ifndef FLASHATTENTION_DISABLE_HDIM192
-                                if (params.d <= 192) {
-                                    if (params.dv <= 128 && Arch == 90) {
-                                        return run_mha_fwd_<Arch, cutlass::half_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
-                                    } else {
-                                        return run_mha_fwd_<Arch, cutlass::half_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
-                                    }
-                                }
-                                #endif
-                                #ifndef FLASHATTENTION_DISABLE_HDIM256
-                                if (params.d <= 256) { return run_mha_fwd_<Arch, cutlass::half_t, 256, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
-                                #endif
-                                #else
-                                TORCH_CHECK(false, "This flash attention build does not support FP16.");
-                                #endif
-                            }
-                        } else {
-                            #ifndef FLASHATTENTION_DISABLE_FP8
-                            #ifndef FLASHATTENTION_DISABLE_HDIM64
-                            if (params.d <= 64) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
-                            #endif
-                            #ifndef FLASHATTENTION_DISABLE_HDIM96
-                            if (params.d <= 96) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
-                            #endif
-                            #ifndef FLASHATTENTION_DISABLE_HDIM128
-                            if (params.d <= 128) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
-                            #endif
-                            #ifndef FLASHATTENTION_DISABLE_HDIM192
-                            if (params.d <= 192) {
-                                if (params.dv <= 128 && Arch == 90) {
-                                    return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
-                                } else {
-                                    return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
-                                }
-                            }
-                            #endif
-                            #ifndef FLASHATTENTION_DISABLE_HDIM256
-                            if (params.d <= 256) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
-                            #endif
-                            #else
-                            TORCH_CHECK(false, "This flash attention build does not support FP8.");
-                            #endif
-                        }
-                    });
-                });
-            });
-        });
-    });
-}
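
Note on the dispatch above: within each dtype branch, the first matching d <= N test wins, so the head dimension is effectively rounded up to the nearest compiled kernel size (with dv additionally selecting a wider-V variant on SM90). A compact standalone sketch of that selection rule; pick_hdim is illustrative only and not part of this source:

    #include <cstdio>

    // Sketch of the head-dim ladder: the first "d <= N" branch that matches wins,
    // so d is effectively rounded up to the nearest compiled head dimension.
    int pick_hdim(int d) {
        if (d <= 64)  return 64;
        if (d <= 96)  return 96;
        if (d <= 128) return 128;
        if (d <= 192) return 192;
        return 256;
    }

    int main() {
        std::printf("%d %d %d\n", pick_hdim(80), pick_hdim(128), pick_hdim(200));  // 96 128 256
        return 0;
    }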
-
-void run_mha_fwd_combine(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl=false) {
-    #ifndef FLASHATTENTION_DISABLE_SPLIT
-    // If hdim is 96 or 192, it's faster to round them to 128 or 256 respectively
-    // so that kBlockM is smaller and we have more parallelism.
-    if (params.is_fp32) {
-        if (params.dv <= 64) {
-            run_mha_fwd_combine_<float, float, 64>(params, stream, enable_pdl);
-        } else {
-            run_mha_fwd_combine_<float, float, 128>(params, stream, enable_pdl);
-        }
-    } else if (params.is_bf16) {
-        if (params.dv <= 64) {
-            run_mha_fwd_combine_<cutlass::bfloat16_t, float, 64>(params, stream, enable_pdl);
-        } else {
-            run_mha_fwd_combine_<cutlass::bfloat16_t, float, 128>(params, stream, enable_pdl);
-        }
-    } else {
-        if (params.dv <= 64) {
-            run_mha_fwd_combine_<cutlass::half_t, float, 64>(params, stream, enable_pdl);
-        } else {
-            run_mha_fwd_combine_<cutlass::half_t, float, 128>(params, stream, enable_pdl);
-        }
-    }
-    #else
-    TORCH_CHECK(false, "This flash attention build does not support combine kernels.");
-    #endif
-}
-
-inline bool get_pagedkv_tma(Flash_fwd_params const& params) {
-    if (params.arch < 90 || !params.page_table || params.leftpad_k || params.knew_ptr) { return false; }
-    // This needs to match the kernel configs
-    auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, false /*paged_kv_non_TMA*/, params.softcap > 0.f);
-    int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90);
-    int const kBlockN = std::get<1>(kBlockMN_kernel_args_sm90);
-    // Heuristic: when seqlen_q <= kBlockM, we're not compute bound, and somehow using TMA is slower,
-    // at least for MLA.
-    return params.page_size % kBlockN == 0 && params.seqlen_q * (params.h / params.h_k) > kBlockM;
-}
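
Note: the return condition above admits TMA only when pages tile evenly into K-blocks and the GQA-packed query length exceeds one M-block. A standalone illustration with assumed tile sizes; kBlockM = 128 and kBlockN = 176 are placeholders, since the real values come from tile_size_fwd_sm90:

    #include <cstdio>

    int main() {
        // Assumed tile sizes (placeholders, not values from tile_size_fwd_sm90).
        int kBlockM = 128, kBlockN = 176;
        int page_size = 352, seqlen_q = 256, h = 32, h_k = 8;
        // Same condition as get_pagedkv_tma's return expression.
        bool use_tma = (page_size % kBlockN == 0) && (seqlen_q * (h / h_k) > kBlockM);
        std::printf("%d\n", (int)use_tma);  // 1: pages tile evenly and the workload is compute-bound
        return 0;
    }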
-
-inline bool get_pack_gqa(Flash_fwd_params const& params) {
-    // Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation and binary size.
-    // Has little effect on speed.
-    if (params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1) { return true; }
-    #ifdef FLASHATTENTION_DISABLE_PACKGQA
-    return false;
-    #else
-    // params.page_table must already be set
-    if (params.h == params.h_k) { return false; }
-    // This needs to match the kernel configs
-    auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f);
-    int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90);
-    return should_pack_gqa(params.cu_seqlens_q || params.seqused_q, params.seqlen_q, params.h / params.h_k, kBlockM);
-    #endif
-}
| 425 | 
            -
             | 
| 426 | 
            -
            inline int get_num_splits(Flash_fwd_params const& params) {
         | 
| 427 | 
            -
                #ifdef FLASHATTENTION_DISABLE_SPLIT
         | 
| 428 | 
            -
                return 1;
         | 
| 429 | 
            -
                #else
         | 
| 430 | 
            -
                // Always enable PackGQA for Split
         | 
| 431 | 
            -
                // params.page_table must already be set
         | 
| 432 | 
            -
                // This needs to match the kernel configs
         | 
| 433 | 
            -
                bool varlen = params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k;
         | 
| 434 | 
            -
                auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f, use_one_mma_wg(params));
         | 
| 435 | 
            -
                // Strictly speaking we need to pass in (varlen && params.num_splits > 1) but num_splits
         | 
| 436 | 
            -
                // has not been set here. It's OK though because we might just underestimate kBlockN a bit
         | 
| 437 | 
            -
                auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, varlen, params.softcap > 0.f, params.knew_ptr);
         | 
| 438 | 
            -
                int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x);
         | 
| 439 | 
            -
                int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x);
         | 
| 440 | 
            -
                int seqlen_q_packgqa = params.seqlen_q * (params.h / params.h_k);
         | 
| 441 | 
            -
                // If is_local, we're not going to load all of seqlen_k
         | 
| 442 | 
            -
                int const seqlen_k_loaded = !params.is_local
         | 
| 443 | 
            -
                    ? params.seqlen_k
         | 
| 444 | 
            -
                    : std::max(0, std::min(params.seqlen_k, params.window_size_right + params.window_size_left + 1 + kBlockM));
         | 
| 445 | 
            -
                int const num_n_blocks = (seqlen_k_loaded + kBlockN - 1) / kBlockN;
         | 
| 446 | 
            -
                int const num_m_blocks = (seqlen_q_packgqa + kBlockM - 1) / kBlockM;
         | 
| 447 | 
            -
                int const size_one_kv_head = params.seqlen_k * (params.d + params.dv) * (params.is_e4m3 ? 1 : 2);
         | 
| 448 | 
            -
                // Always enable PackGQA for Split
         | 
| 449 | 
            -
                // If varlen, we use dynamic split, so this heuristic just needs to get an upper bound on num_splits.
         | 
| 450 | 
            -
                // We assume the case where there's 1 long sequence and the rest are short, i.e. pretending
         | 
| 451 | 
            -
                // that batch = 1.
         | 
| 452 | 
            -
                int total_mblocks = (params.num_splits_dynamic_ptr ? 1 : params.b) * params.h_k * num_m_blocks;
         | 
| 453 | 
            -
                return num_splits_heuristic(total_mblocks, params.num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, params.is_causal || params.is_local, 128);
         | 
| 454 | 
            -
                #endif
         | 
| 455 | 
            -
            }
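
// Illustrative worked example (not part of the original file; tile sizes are
// hypothetical, the real ones come from tile_size_fwd_sm90/tile_size_fwd_sm8x):
// decode with batch = 1, h = 32, h_k = 8 (4:1 GQA), seqlen_q = 1, seqlen_k = 8192,
// kBlockM = 128, kBlockN = 176. Then seqlen_q_packgqa = 1 * (32 / 8) = 4,
// num_m_blocks = ceil(4 / 128) = 1, num_n_blocks = ceil(8192 / 176) = 47, and with
// dynamic split enabled total_mblocks = 1 * 8 * 1 = 8. Eight m-blocks cannot keep
// 100+ SMs busy, so the heuristic picks num_splits > 1 and splits the seqlen_k
// dimension across more SMs.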

inline int get_max_headdim() {
    #ifndef FLASHATTENTION_DISABLE_HDIM256
    return 256;
    #endif
    #ifndef FLASHATTENTION_DISABLE_HDIM192
    return 192;
    #endif
    #ifndef FLASHATTENTION_DISABLE_HDIM128
    return 128;
    #endif
    #ifndef FLASHATTENTION_DISABLE_HDIM96
    return 96;
    #endif
    #ifndef FLASHATTENTION_DISABLE_HDIM64
    return 64;
    #endif
    return 0;
}

inline int round_up_headdim(int head_size) {
    #ifndef FLASHATTENTION_DISABLE_HDIM64
    if (head_size <= 64) { return 64; }
    #endif
    #ifndef FLASHATTENTION_DISABLE_HDIM96
    if (head_size <= 96) { return 96; }
    #endif
    #ifndef FLASHATTENTION_DISABLE_HDIM128
    if (head_size <= 128) { return 128; }
    #endif
    #ifndef FLASHATTENTION_DISABLE_HDIM192
    if (head_size <= 192) { return 192; }
    #endif
    #ifndef FLASHATTENTION_DISABLE_HDIM256
    if (head_size <= 256) { return 256; }
    #endif
    return 256;
}

inline int round_up_headdimv(int head_size) {
    if (head_size <= 64) { return 64; }
    if (head_size <= 96) { return 96; }
    if (head_size <= 128) { return 128; }
    if (head_size <= 192) { return 192; }
    if (head_size <= 256) { return 256; }
    return 512;
}
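
// Illustrative examples (not part of the original file): in a full build,
// round_up_headdim(80) returns 96 and round_up_headdim(160) returns 192, since a
// head dimension is rounded up to the next instantiated kernel size. round_up_headdimv
// additionally allows 512, so round_up_headdimv(384) returns 512; anything larger
// falls through to the final return and is caught by the head-dim checks in mha_fwd.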

// Only applicable to the case where seqused_k (i.e. cache_seqlens) is available
at::Tensor
mha_fwd_get_scheduler_metadata(
        int batch_size,
        int max_seqlen_q,
        int max_seqlen_k,
        int num_heads,
        int num_heads_k,
        int headdim,
        int headdim_v,
        at::ScalarType qkv_dtype,
        const at::Tensor &seqused_k, // b
        std::optional<const at::Tensor> &cu_seqlens_q_,  // b+1
        std::optional<const at::Tensor> &cu_seqlens_k_,  // b+1
        std::optional<const at::Tensor> &cu_seqlens_k_new_,  // b+1
        std::optional<const at::Tensor> &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
        std::optional<const at::Tensor> &leftpad_k_, // b
        std::optional<int> page_size,
        int max_seqlen_k_new,  // 0 means we're not appending new KV
        bool is_causal,
        int window_size_left,
        int window_size_right,
        bool has_softcap,
        int num_splits,
        std::optional<bool> pack_gqa_,
        int const sm_margin
        ) {

    TORCH_CHECK(qkv_dtype == at::ScalarType::Half || qkv_dtype == at::ScalarType::BFloat16 || qkv_dtype == at::ScalarType::Float8_e4m3fn,
                "FlashAttention only supports fp16, bf16, and fp8_e4m3 data types");
    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    // Reset the parameters
    Flash_fwd_params params{};
    params.is_bf16 = qkv_dtype == at::ScalarType::BFloat16;
    params.is_e4m3 = qkv_dtype == at::ScalarType::Float8_e4m3fn;
    params.b = batch_size;
    params.seqlen_q = max_seqlen_q;
    params.seqlen_k = max_seqlen_k;
    params.h = num_heads;
    params.h_k = num_heads_k;
    params.d = headdim;
    params.dv = headdim_v;
    params.d_rounded = round_up_headdim(headdim);
    params.dv_rounded = headdim_v == headdim ? params.d_rounded : round_up_headdimv(headdim_v);
    params.seqlen_knew = max_seqlen_k_new;

    bool const is_varlen_q = cu_seqlens_q_.has_value();
    params.cu_seqlens_q = is_varlen_q ? cu_seqlens_q_.value().data_ptr<int>() : nullptr;
    bool const is_varlen_k = cu_seqlens_k_.has_value();
    params.cu_seqlens_k = is_varlen_k ? cu_seqlens_k_.value().data_ptr<int>() : nullptr;
    params.cu_seqlens_knew = cu_seqlens_k_new_.has_value() ? cu_seqlens_k_new_.value().data_ptr<int>() : nullptr;
    params.seqused_q = seqused_q_.has_value() ? seqused_q_.value().data_ptr<int>() : nullptr;
    params.seqused_k = seqused_k.data_ptr<int>();
    params.leftpad_k = leftpad_k_.has_value() ? leftpad_k_.value().data_ptr<int>() : nullptr;
    params.knew_ptr = params.seqlen_knew > 0 ? reinterpret_cast<int*>(1) : nullptr;
    if (window_size_left >= max_seqlen_k - 1) { window_size_left = -1; }
    if (window_size_right >= max_seqlen_q - 1) { window_size_right = -1; }
    // causal=true is the same as causal=false in this case
    if (max_seqlen_q == 1 && window_size_left == -1 && window_size_right == -1) {
        // Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA
        if ((headdim <= 64 || headdim > 128) || !page_size.has_value()) {
            is_causal = false;
        }
    }
    if (is_causal) { window_size_right = 0; }

    params.is_causal = window_size_left < 0 && window_size_right == 0;
    params.is_local = (window_size_left >= 0 || window_size_right >= 0) && !params.is_causal;
    if (window_size_left < 0 && window_size_right >= 0) { window_size_left = max_seqlen_k - 1; }
    if (window_size_left >= 0 && window_size_right < 0) { window_size_right = max_seqlen_q - 1; }
    params.window_size_left = window_size_left;
    params.window_size_right = window_size_right;
    params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor;
    params.num_sm = at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin;
    params.softcap = has_softcap ? 1.0f : 0.0f;

    params.page_size = page_size.has_value() ? page_size.value() : 1;
    params.page_table = !page_size.has_value() ? nullptr : reinterpret_cast<int*>(1);

    bool const use_dynamic_split = params.b <= 992;
    params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast<int*>(1);
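    // Note on the reinterpret_cast<int*>(1) pattern above (editorial, illustrative):
    // this metadata-only entry point has no real K/V or split tensors to point at,
    // but the downstream heuristics (get_pagedkv_tma, get_pack_gqa, get_num_splits)
    // only test params.knew_ptr / params.page_table / params.num_splits_dynamic_ptr
    // for null vs. non-null. A non-null sentinel address that is never dereferenced
    // is enough to steer them down the same paths the real mha_fwd call will take.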

    params.pagedkv_tma = get_pagedkv_tma(params);
    // Determine if we should pack GQA before num_splits since it impacts use_one_mma_wg (in get_num_splits)
    params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params);
    params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits;
    // Always enable PackGQA for Split
    params.pack_gqa = params.num_splits > 1;

    bool is_varlen = true;

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)seqused_k.get_device()};

    auto opts = seqused_k.options();
    // This needs to be set after get_num_splits
    at::Tensor tile_count_semaphore;  // Contains the semaphore and optionally num_splits_dynamic
    bool const scheduler_needs_semaphore = params.arch >= 90 || params.num_splits > 1;
    if (scheduler_needs_semaphore || use_dynamic_split) {
        tile_count_semaphore = torch::empty({int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b}, opts.dtype(torch::kInt32));
        if (scheduler_needs_semaphore) {
            if (!use_dynamic_split) { tile_count_semaphore.zero_(); }  // If varlen we'll manually do the zero-ing
            params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
        } else {
            params.tile_count_semaphore = nullptr;
        }
        params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr<int>() + 1 : nullptr;
    }

    if (params.num_splits_dynamic_ptr) {
        auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f, use_one_mma_wg(params));
        auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, is_varlen && params.num_splits > 1, params.softcap > 0.f, params.knew_ptr);
        int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x);
        int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x);
        auto stream = at::cuda::getCurrentCUDAStream().stream();
        prepare_varlen_num_blocks(params, stream, params.pack_gqa, kBlockM, kBlockN, false /*enable_pdl*/);
        CHECK_CUDA_KERNEL_LAUNCH();
    }
    return tile_count_semaphore;
}
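
// Usage sketch (illustrative, not part of the original file; variable names are
// hypothetical): for a serving loop that decodes many steps with the same batch
// layout, the metadata tensor returned here can be computed once and then passed
// to mha_fwd via its scheduler_metadata_ argument, which makes mha_fwd set
// params.skip_scheduler_metadata_computation and reuse the precomputed buffer:
//
//   at::Tensor sched = mha_fwd_get_scheduler_metadata(
//       /*batch_size=*/bs, /*max_seqlen_q=*/1, /*max_seqlen_k=*/max_k,
//       /*num_heads=*/h, /*num_heads_k=*/h_k, /*headdim=*/d, /*headdim_v=*/d,
//       at::ScalarType::BFloat16, seqused_k,
//       /*remaining optionals and flags as in the signature above*/ ...);
//   // ... then pass `sched` as scheduler_metadata_ on every decode step's mha_fwd.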

// b: batch_size
// b_k: batch_size_k
// s_q: seqlen_q
// s_k: seqlen_k
// s_k_new: seqlen_k_new
// h: num_heads
// h_k: num_heads_k
// d: head_size
std::vector<at::Tensor>
mha_fwd(at::Tensor &q,   // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
        const at::Tensor &k,  // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table.
        const at::Tensor &v,  // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table.
        std::optional<const at::Tensor> &k_new_,  // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
        std::optional<const at::Tensor> &v_new_,  // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
        std::optional<const at::Tensor> &q_v_,  // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
        std::optional<at::Tensor> &out_,  // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
        std::optional<const at::Tensor> &cu_seqlens_q_,  // b+1
        std::optional<const at::Tensor> &cu_seqlens_k_,  // b+1
        std::optional<const at::Tensor> &cu_seqlens_k_new_,  // b+1
        std::optional<const at::Tensor> &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
        std::optional<const at::Tensor> &seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
        std::optional<int> max_seqlen_q_,
        // TODO: check if we need max_seqlen_k
        std::optional<int> max_seqlen_k_,
        std::optional<const at::Tensor> &page_table_, // (b_k, max_num_pages_per_seq)
        std::optional<const at::Tensor> &kv_batch_idx_, // b. indices to index into the KV cache
        std::optional<const at::Tensor> &leftpad_k_, // b
        std::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
        std::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
        std::optional<const at::Tensor> &seqlens_rotary_, // b
        std::optional<at::Tensor> &q_descale_,  // (b, h_k), not (b, h)
        std::optional<at::Tensor> &k_descale_,  // (b, h_k)
        std::optional<at::Tensor> &v_descale_,  // (b, h_k)
        float const softmax_scale,
        bool is_causal,
        int window_size_left,
        int window_size_right,
        float const softcap,
        bool const is_rotary_interleaved,   // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
        std::optional<at::Tensor> &scheduler_metadata_,  // (b + 1)
        int num_splits,
        std::optional<bool> pack_gqa_,
        int const sm_margin,
        std::optional<const at::Tensor> &s_aux_ // (h)
        ) {

    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm8x = dprops->major >= 8;
    TORCH_CHECK(is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");

    auto q_type = q.scalar_type();
    TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16 || q_type == at::ScalarType::Float8_e4m3fn,
                "FlashAttention only supports fp16, bf16, and fp8_e4m3 data types");
    if (dprops->major < 9) {
        TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,
                    "FlashAttention on Ampere/Ada cards only supports fp16 and bf16 data types");
    }
    TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype");
    TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype");

    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);

    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");

    at::Tensor page_table;
    const bool paged_KV = page_table_.has_value();
    if (paged_KV) {
        page_table = page_table_.value();
        CHECK_DEVICE(page_table);
        TORCH_CHECK(page_table.dtype() == torch::kInt32, "page_table must have dtype torch.int32");
        TORCH_CHECK(page_table.stride(-1) == 1, "page_table must have contiguous last dimension");
    }

    at::Tensor cu_seqlens_q;
    bool const is_varlen_q = cu_seqlens_q_.has_value();
    if (is_varlen_q) {
        cu_seqlens_q = cu_seqlens_q_.value();
        CHECK_DEVICE(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_q);
        TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32");
        TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided if cu_seqlens_q is provided");
    }
    at::Tensor cu_seqlens_k;
    bool const is_varlen_k = cu_seqlens_k_.has_value();
    if (is_varlen_k) {
        cu_seqlens_k = cu_seqlens_k_.value();
        CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k);
        TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32");
        TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided if cu_seqlens_k is provided");
        TORCH_CHECK(!paged_KV, "If cu_seqlens_k is passed in, then page table is not supported");
        TORCH_CHECK(!kv_batch_idx_.has_value(), "If cu_seqlens_k is passed in, then kv_batch_idx is not supported");
    }

    auto const sizes = q.sizes();
    const int batch_size = !is_varlen_q ? sizes[0] : cu_seqlens_q.size(0) - 1;
    int seqlen_q = !is_varlen_q ? sizes[1] : max_seqlen_q_.value();
    int total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0];
    int num_heads = q.size(-2);
    int const head_size = q.size(-1);
    int const head_size_v = v.size(-1);
    int const max_num_pages_per_seq = !paged_KV ? 0 : page_table.size(1);
    int const num_pages = !paged_KV ? 0 : k.size(0);
    int const page_size = !paged_KV ? 1 : k.size(1);
    int const seqlen_k = !max_seqlen_k_.has_value() ? (!paged_KV ? k.size(1) : max_num_pages_per_seq * page_size) : max_seqlen_k_.value();
    int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0);
    int const num_heads_k = k.size(-2);
    int const batch_size_k = !paged_KV ? (!is_varlen_k ? k.size(0) : cu_seqlens_k.size(0) - 1) : page_table.size(0);
    if (!kv_batch_idx_.has_value()) {
        TORCH_CHECK(batch_size == batch_size_k, "batch_size must be equal to batch_size_k");
    }
    int const max_headdim = get_max_headdim();
    TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim));
    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
    if (head_size_v != head_size) {
        TORCH_CHECK((head_size > 128 && head_size <= 192 && head_size_v > 96 && head_size_v <= 128) ||
                   (head_size <= 64 && head_size_v <= 512),
                   "If V headdim is different from Q/K dim, we only support Q/K headdim in (128, 192] and V headdim in (96, 128], "
                   "or (Q/K <= 64 and V <= 512).");
        TORCH_CHECK(dprops->major == 9, "Only Hopper supports different V headdim");
        if (head_size_v > 256) {
            TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,
                        "HeaddimV > 256 requires fp16 or bf16 data type");
        }
    }
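    // Concrete instances of the mismatched-headdim rule (editorial, illustrative):
    // Q/K headdim 192 with V headdim 128 is accepted (e.g. MLA-style models), as is
    // Q/K headdim 64 with V headdim up to 512; Q/K 128 with V 64 is rejected because
    // it matches neither allowed range.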

    // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM
    // TODO: check this
    if (window_size_left >= seqlen_k - 1) { window_size_left = -1; }
    if (window_size_right >= seqlen_q - 1) { window_size_right = -1; }
    // causal=true is the same as causal=false in this case
    if (seqlen_q == 1 && window_size_left == -1 && window_size_right == -1) {
        // Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA
        if ((head_size <= 64 || head_size > 128) || !paged_KV) {
            is_causal = false;
        }
    }
    if (is_causal) { window_size_right = 0; }
    // There's a case where is_causal=false, window_size=(-1, 0). Then set_params_fprop will set params.is_causal=true.
    // If we don't have is_causal here matching params.is_causal, we might get the wrong kBlockM.
    is_causal = window_size_left < 0 && window_size_right == 0;
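    // Worked example of the normalization above (editorial, illustrative): a causal
    // call with window_size = (-1, -1) ends up as (-1, 0), i.e. attend to everything
    // at or before the diagonal. A decode call with seqlen_q == 1 and a full window
    // drops is_causal entirely (with a single query, causal and non-causal masks
    // coincide), except in the paged-KV case with head_size in (64, 128], which
    // keeps causal to get the kBlockN=128 tile size.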

    if (!is_varlen_q) {
        CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
    } else {
        CHECK_SHAPE(q, total_q, num_heads, head_size);
        CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
    }
    if (!paged_KV) {
        if (!is_varlen_k) {
            CHECK_SHAPE(k, batch_size_k, seqlen_k, num_heads_k, head_size);
            CHECK_SHAPE(v, batch_size_k, seqlen_k, num_heads_k, head_size_v);
        } else {
            CHECK_SHAPE(k, total_k, num_heads_k, head_size);
            CHECK_SHAPE(v, total_k, num_heads_k, head_size_v);
            CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
        }
    } else {
        CHECK_SHAPE(k, num_pages, page_size, num_heads_k, head_size);
        CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size_v);
        CHECK_SHAPE(page_table, batch_size_k, max_num_pages_per_seq);
    }

    if (seqused_q_.has_value()) {
        auto seqused_q = seqused_q_.value();
        TORCH_CHECK(seqused_q.dtype() == torch::kInt32, "seqused_q must have dtype int32");
        CHECK_DEVICE(seqused_q); CHECK_CONTIGUOUS(seqused_q);
        CHECK_SHAPE(seqused_q, batch_size);
    }
    if (seqused_k_.has_value()) {
        auto seqused_k = seqused_k_.value();
        TORCH_CHECK(seqused_k.dtype() == torch::kInt32, "seqused_k must have dtype int32");
        CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k);
        CHECK_SHAPE(seqused_k, batch_size);
    }

    if (leftpad_k_.has_value()) {
        auto leftpad_k = leftpad_k_.value();
        TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32");
        CHECK_DEVICE(leftpad_k); CHECK_CONTIGUOUS(leftpad_k);
        CHECK_SHAPE(leftpad_k, batch_size);
    }

    // This is what we will template on
    bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value() || leftpad_k_.has_value();
    #ifdef FLASHATTENTION_DISABLE_VARLEN
        TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen.");
    #endif

    int const alignment = q_type == torch::kFloat8_e4m3fn ? 16 : 8;
    TORCH_CHECK(head_size % alignment == 0, "head_size should be a multiple of " + std::to_string(alignment));
    TORCH_CHECK(head_size_v % alignment == 0, "head_size_v should be a multiple of " + std::to_string(alignment));
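    // Example (editorial, illustrative): fp16/bf16 require 8-element alignment while
    // fp8_e4m3 requires 16, so head_size = 72 passes for bf16 (72 % 8 == 0) but is
    // rejected for fp8 (72 % 16 != 0).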

    auto opts = q.options();
    auto out_type = q_type == at::ScalarType::Float8_e4m3fn ? at::ScalarType::BFloat16 : q_type;
    at::Tensor out;
    if (out_.has_value()) {
        out = out_.value();
        TORCH_CHECK(out.scalar_type() == out_type, "For FP16/BF16 input, output must have the same dtype as inputs. For FP8 input, output must have dtype BF16");
        CHECK_DEVICE(out);
        TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
        if (!is_varlen_q) {
            CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_v);
        } else {
            CHECK_SHAPE(out, total_q, num_heads, head_size_v);
        }
    } else {
        out = !is_varlen_q
            ? torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts.dtype(out_type))
            : torch::empty({total_q, num_heads, head_size_v}, opts.dtype(out_type));
    }

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    int const head_size_rounded = round_up_headdim(head_size);
    int const head_size_v_rounded = head_size_v == head_size ? head_size_rounded : round_up_headdimv(head_size_v);
    int const seqlen_q_rounded = round_multiple(seqlen_q, 128);
    int const seqlen_k_rounded = round_multiple(seqlen_k, 128);
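    // Example (editorial, illustrative): seqlen_q = 1000 gives seqlen_q_rounded =
    // round_multiple(1000, 128) = 1024, since (1000 + 127) / 128 = 8 blocks of 128.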

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)q.get_device()};

    at::Tensor softmax_lse;
    if (!is_varlen_q) {
        softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
    } else {
        softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
    }
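    // Note (editorial, illustrative): the LSE layout differs between the two paths.
    // In the fixed-length path it is (batch, heads, seqlen_q); in the varlen path the
    // batch and sequence dims are fused into (heads, total_q), matching the packed Q.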

    Flash_fwd_params params;
    set_params_fprop(params,
                     batch_size,
                     seqlen_q, seqlen_k,
                     seqlen_q_rounded, seqlen_k_rounded,
                     num_heads, num_heads_k,
                     head_size, head_size_rounded,
                     q, k, v, out,
                     !is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(),
                     !is_varlen_k ? nullptr : cu_seqlens_k.data_ptr(),
                     seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr,
                     seqused_k_.has_value() ? seqused_k_.value().data_ptr() : nullptr,
                     softmax_lse.data_ptr(),
                     /*p_dropout=*/0.f,
                     softmax_scale,
                     window_size_left,
                     window_size_right,
                     softcap,
                     sm_margin);
    params.total_q = total_q;
    params.total_k = total_k;
    params.b_k = batch_size_k;
    params.dv = head_size_v;
    params.dv_rounded = head_size_v_rounded;
    if (leftpad_k_.has_value()) {  // This needs to be set before get_pagedkv_tma
        params.leftpad_k = static_cast<int *>(leftpad_k_.value().data_ptr());
    }
    if (paged_KV) {
        params.page_table = page_table.data_ptr<int>();
        params.page_table_batch_stride = page_table.stride(0);
    }
    params.page_size = page_size;
    params.num_pages = num_pages;

    if (k_new_.has_value()) {  // This needs to be set before get_pagedkv_tma
        at::Tensor k_new, v_new;
        TORCH_CHECK(v_new_.has_value(), "If k_new is supplied, v_new must also be passed in");
        TORCH_CHECK(seqused_k_.has_value(), "If k_new is supplied, seqused_k must also be passed in");
        TORCH_CHECK(seqlen_q <= seqlen_k, "If k_new is supplied, it must have seqlen <= the seqlen of the KV cache");
        at::Tensor cu_seqlens_k_new;
        bool const is_varlen_k_new = cu_seqlens_k_new_.has_value();
        if (is_varlen_k_new) {
            cu_seqlens_k_new = cu_seqlens_k_new_.value();
            CHECK_DEVICE(cu_seqlens_k_new); CHECK_CONTIGUOUS(cu_seqlens_k_new);
            TORCH_CHECK(cu_seqlens_k_new.dtype() == torch::kInt32, "cu_seqlens_k_new must have dtype torch.int32");
        }
        k_new = k_new_.value();
        v_new = v_new_.value();
        TORCH_CHECK(k_new.dtype() == q_type, "k_new must have the same dtype as query");
        TORCH_CHECK(v_new.dtype() == q_type, "v_new must have the same dtype as query");
        CHECK_DEVICE(k_new); CHECK_DEVICE(v_new);
        TORCH_CHECK(k_new.stride(-1) == 1, "k_new tensor must have contiguous last dimension");
        TORCH_CHECK(v_new.stride(-1) == 1, "v_new tensor must have contiguous last dimension");
        // We don't need max_seqlen_k_new, so seqlen_k_new can be whatever when is_varlen_k_new
        int seqlen_k_new = !is_varlen_k_new ? k_new.size(1) : 0;
        int total_k_new = !is_varlen_k_new ? batch_size * k_new.size(1) : k_new.size(0);
        if (!is_varlen_k_new) {
            CHECK_SHAPE(k_new, batch_size, seqlen_k_new, num_heads_k, head_size);
            CHECK_SHAPE(v_new, batch_size, seqlen_k_new, num_heads_k, head_size_v);
        } else {
            CHECK_SHAPE(k_new, total_k_new, num_heads_k, head_size);
            CHECK_SHAPE(v_new, total_k_new, num_heads_k, head_size_v);
            CHECK_SHAPE(cu_seqlens_k_new, batch_size + 1);
        }
        params.seqlen_knew = seqlen_k_new;
        params.total_knew = total_k_new;
        params.knew_ptr = k_new.data_ptr();
        params.vnew_ptr = v_new.data_ptr();
        // All strides are in elements, not bytes.
        params.knew_row_stride = k_new.stride(-3);
        params.vnew_row_stride = v_new.stride(-3);
        params.knew_head_stride = k_new.stride(-2);
        params.vnew_head_stride = v_new.stride(-2);
        if (!is_varlen_k_new) {
            params.knew_batch_stride = k_new.stride(0);
            params.vnew_batch_stride = v_new.stride(0);
        }
        if (is_varlen_k_new) {
            params.cu_seqlens_knew = static_cast<int*>(cu_seqlens_k_new.data_ptr());
        }
    }

    // 992 = 32 * 31 is the max supported batch in prepare_varlen_num_blocks kernel
    bool const use_dynamic_split = is_varlen && params.b <= 992;
    // Temporarily set num_splits_dynamic_ptr to 1 since get_num_splits checks it
    params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast<int*>(1);

    params.pagedkv_tma = get_pagedkv_tma(params);
    // Determine if we should pack GQA before num_splits since it impacts use_one_mma_wg (in get_num_splits)
    params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params);
    params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits;
    // Always enable PackGQA for Split
    params.pack_gqa = params.num_splits > 1;

    // This needs to be set after get_num_splits
    at::Tensor tile_count_semaphore;  // Contains the semaphore and optionally num_splits_dynamic
    // We don't use the persistent scheduler if Split and not Varlen
    bool const scheduler_needs_semaphore = params.arch >= 90
        ? (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen)
        : ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1));
    if (scheduler_needs_semaphore || use_dynamic_split) {
        int metadata_size = int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b;
        params.skip_scheduler_metadata_computation = scheduler_metadata_.has_value();
        if (scheduler_metadata_.has_value()) {
            at::Tensor scheduler_metadata = scheduler_metadata_.value();
            CHECK_DEVICE(scheduler_metadata);
            CHECK_SHAPE(scheduler_metadata, metadata_size);
            CHECK_CONTIGUOUS(scheduler_metadata);
            TORCH_CHECK(scheduler_metadata.dtype() == torch::kInt32, "scheduler_metadata must have dtype int32");
            tile_count_semaphore = scheduler_metadata;
        } else {
            tile_count_semaphore = torch::empty({metadata_size}, opts.dtype(torch::kInt32));
        }
        if (scheduler_needs_semaphore && !use_dynamic_split) {
            tile_count_semaphore.zero_();  // If varlen we'll manually do the zero-ing
        }
        params.tile_count_semaphore = scheduler_needs_semaphore ? tile_count_semaphore.data_ptr<int>() : nullptr;
        params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr<int>() + 1 : nullptr;
    }
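    // Layout of tile_count_semaphore (editorial, illustrative): with batch_size = 4
    // and both flags set, metadata_size = 1 + 4 = 5 int32s; element 0 is the
    // tile-count semaphore and elements 1..4 hold the per-batch dynamic split
    // counts, which is why num_splits_dynamic_ptr points one element past the start.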
         | 
| 975 | 
            -
             | 
| 976 | 
            -
                if (q_v_.has_value()) {
         | 
| 977 | 
            -
                    TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64");
         | 
| 978 | 
            -
                    TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,
         | 
| 979 | 
            -
                                "q_v is only supported for fp16 and bf16 data type");
         | 
| 980 | 
            -
                    TORCH_CHECK(params.arch == 90, "q_v is only supported for Hopper GPUs");
         | 
| 981 | 
            -
                    at::Tensor q_v = q_v_.value();
         | 
| 982 | 
            -
                    TORCH_CHECK(q_v.dtype() == q_type, "q_v must have the same dtype as query");
         | 
| 983 | 
            -
                    CHECK_DEVICE(q_v);
         | 
| 984 | 
            -
                    TORCH_CHECK(q_v.stride(-1) == 1, "q_v tensor must have contiguous last dimension");
         | 
| 985 | 
            -
                    if (!is_varlen_q) {
         | 
| 986 | 
            -
                        CHECK_SHAPE(q_v, batch_size, seqlen_q, num_heads, head_size_v);
         | 
| 987 | 
            -
                    } else {
         | 
| 988 | 
            -
                        CHECK_SHAPE(q_v, total_q, num_heads, head_size_v);
         | 
| 989 | 
            -
                    }
         | 
| 990 | 
            -
                    params.qv_ptr = q_v.data_ptr();
         | 
| 991 | 
            -
        // All strides are in elements, not bytes.
        params.qv_row_stride = q_v.stride(-3);
        params.qv_head_stride = q_v.stride(-2);
        if (!is_varlen_q) {
            params.qv_batch_stride = q_v.stride(0);
        }
    }

    if (rotary_cos_.has_value()) {
        TORCH_CHECK(k_new_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided");
        auto rotary_cos = rotary_cos_.value();
        CHECK_DEVICE(rotary_cos); CHECK_CONTIGUOUS(rotary_cos);
        params.rotary_dim = rotary_cos.size(1) * 2;
        TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim");
        TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported");
        const int seqlen_ro = rotary_cos.size(0);
        if (paged_KV) {
            TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache");
        }
        CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2);
        TORCH_CHECK(rotary_cos.scalar_type() == q_type, "rotary_cos must have the same dtype as query");

        TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided");
        auto rotary_sin = rotary_sin_.value();
        CHECK_DEVICE(rotary_sin); CHECK_CONTIGUOUS(rotary_sin);
        CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2);
        TORCH_CHECK(rotary_sin.scalar_type() == q_type, "rotary_sin must have the same dtype as query");
        params.rotary_cos_ptr = rotary_cos.data_ptr();
        params.rotary_sin_ptr = rotary_sin.data_ptr();
        params.is_rotary_interleaved = is_rotary_interleaved;
        if (seqlens_rotary_.has_value()) {
            at::Tensor seqlens_rotary = seqlens_rotary_.value();
            CHECK_DEVICE(seqlens_rotary); CHECK_CONTIGUOUS(seqlens_rotary);
            TORCH_CHECK(seqlens_rotary.dtype() == torch::kInt32, "seqlens_rotary must have dtype torch.int32");
            CHECK_SHAPE(seqlens_rotary, batch_size);
            params.seqlens_rotary = seqlens_rotary.data_ptr<int>();
        }
    } else {
        params.rotary_dim = 0;
    }
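    // (Example: with head_size = 128 and rotary_dim = 64, rotary_cos and
    // rotary_sin above must each have shape (seqlen_ro, 32).)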

    if (kv_batch_idx_.has_value()) {
        auto kv_batch_idx = kv_batch_idx_.value();
        CHECK_DEVICE(kv_batch_idx); CHECK_CONTIGUOUS(kv_batch_idx);
        TORCH_CHECK(kv_batch_idx.scalar_type() == torch::kInt32, "kv_batch_idx must have dtype int32");
        params.kv_batch_idx = reinterpret_cast<int *>(kv_batch_idx.data_ptr());
    }

    at::Tensor out_accum, softmax_lse_accum;
    auto outaccum_type = at::ScalarType::Float;
    if (params.num_splits > 1) {
        TORCH_CHECK(params.num_splits <= 256, "num_splits > 256 not supported");
        if (!is_varlen_q) {
            out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_v}, opts.dtype(outaccum_type));
            softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
            params.oaccum_batch_stride = out_accum.stride(1);
            params.lseaccum_batch_stride = softmax_lse_accum.stride(1);
        } else {
            out_accum = torch::empty({params.num_splits, num_heads, total_q, head_size_v}, opts.dtype(outaccum_type));
            softmax_lse_accum = torch::empty({params.num_splits, num_heads, total_q}, opts.dtype(at::kFloat));
        }
        params.is_fp32 = false;
        params.oaccum_ptr = out_accum.data_ptr();
        params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
        params.oaccum_split_stride = out_accum.stride(0);
        params.oaccum_row_stride = out_accum.stride(-2);
        params.oaccum_head_stride = out_accum.stride(-3);
        params.lseaccum_split_stride = softmax_lse_accum.stride(0);
        params.lseaccum_head_stride = softmax_lse_accum.stride(-2);
    }
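    // Sketch of the semantics assumed for the split combine (performed by
    // run_mha_fwd_combine further down): with per-split partials
    //     out_accum[i], softmax_lse_accum[i],
    // the merged result is
    //     lse = logsumexp_i(softmax_lse_accum[i])
    //     out = sum_i exp(softmax_lse_accum[i] - lse) * out_accum[i],
    // i.e. each split's output is re-weighted by its share of the global
    // softmax normalizer.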

    if (q_type == at::ScalarType::Float8_e4m3fn) {
        if (q_descale_.has_value()) {
            auto q_descale = q_descale_.value();
            CHECK_DEVICE(q_descale);
            CHECK_SHAPE(q_descale, batch_size, num_heads_k);
            params.q_descale_ptr = q_descale.data_ptr<float>();
            params.q_descale_batch_stride = q_descale.stride(0);
            params.q_descale_head_stride = q_descale.stride(1);
        } else {
            params.q_descale_ptr = nullptr;
        }
        if (k_descale_.has_value()) {
            auto k_descale = k_descale_.value();
            CHECK_DEVICE(k_descale);
            CHECK_SHAPE(k_descale, batch_size, num_heads_k);
            params.k_descale_ptr = k_descale.data_ptr<float>();
            params.k_descale_batch_stride = k_descale.stride(0);
            params.k_descale_head_stride = k_descale.stride(1);
        } else {
            params.k_descale_ptr = nullptr;
        }
        if (v_descale_.has_value()) {
            auto v_descale = v_descale_.value();
            CHECK_DEVICE(v_descale);
            CHECK_SHAPE(v_descale, batch_size, num_heads_k);
            params.v_descale_ptr = v_descale.data_ptr<float>();
            params.v_descale_batch_stride = v_descale.stride(0);
            params.v_descale_head_stride = v_descale.stride(1);
        } else {
            params.v_descale_ptr = nullptr;
        }
    }

    if (s_aux_.has_value()) {
        auto s_aux = s_aux_.value();
        TORCH_CHECK(s_aux.scalar_type() == at::ScalarType::BFloat16,
            "We only support bf16 dtype for S extra.");
        CHECK_DEVICE(s_aux);
        CHECK_SHAPE(s_aux, num_heads);
        CHECK_CONTIGUOUS(s_aux);
        params.s_aux_ptr = s_aux.data_ptr();
    } else {
        params.s_aux_ptr = nullptr;
    }
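    // (s_aux appears to carry the per-head attention-sink values mentioned in
    // the README: one bf16 scalar per query head that participates in the
    // softmax normalization.)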

    #ifdef FLASHATTENTION_DISABLE_LOCAL
    TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
    #endif
    #ifdef FLASHATTENTION_DISABLE_SOFTCAP
    TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping.");
    #endif
    #ifdef FLASHATTENTION_DISABLE_SPLIT
    TORCH_CHECK(params.num_splits == 1, "This flash attention build does not support splits.");
    #endif
    #ifdef FLASHATTENTION_DISABLE_PACKGQA
    TORCH_CHECK(!params.pack_gqa || params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1, "This flash attention build does not support pack_gqa.");
    #endif
    #ifdef FLASHATTENTION_DISABLE_PAGEDKV
    TORCH_CHECK(!(params.page_table && !params.pagedkv_tma), "This flash attention build does not support paged KV.");
    #endif
    #ifdef FLASHATTENTION_DISABLE_APPENDKV
    TORCH_CHECK(!k_new_.has_value(), "This flash attention build does not support appending KV.");
    #endif
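    // (These FLASHATTENTION_DISABLE_* macros are compile-time trim options; a
    // build that omits a feature fails loudly here instead of dispatching to a
    // kernel that was never instantiated.)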

    if (total_q > 0 && (total_k + params.total_knew) > 0 && num_heads_k > 0) {
        auto stream = at::cuda::getCurrentCUDAStream().stream();
        run_mha_fwd(params, stream);
        if (params.num_splits > 1) {
            if (out_type == at::ScalarType::BFloat16) {
                // Set is_bf16 since we want the final output in BF16; otherwise fwd_combine would write FP16.
                params.is_bf16 = true;
            }
            // Unless there's seqused_q, for the purpose of attn_combine, we can just treat it as batch=1
            // and seqlen = total_q, and don't need to dispatch to Varlen there.
            // However, with dynamic split, each row needs to know which batch it belongs to
            // to read the number of splits, so we just use the varlen version of combine kernel.
            // if (is_varlen_q && !seqused_q_.has_value()) {
            // if (is_varlen_q) {
            //     params.b = 1;
            //     params.seqlen_q = total_q;
            // }
            // This will zero out the semaphore if needed
            run_mha_fwd_combine(params, stream, true /*enable_pdl*/);
        } else if (scheduler_needs_semaphore && params.skip_scheduler_metadata_computation) {
            // need to zero out the semaphore in this case
            tile_count_semaphore.index({torch::indexing::Slice(0, 1)}).zero_();
        }
    } else if (total_q > 0 && num_heads_k > 0) {
        // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
        out.zero_();
        softmax_lse.fill_(std::numeric_limits<float>::infinity());
    }

    // return {out, softmax_lse};
    return {out, softmax_lse, out_accum, softmax_lse_accum};
}

void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
    #ifndef FLASHATTENTION_DISABLE_BACKWARD
        // FP16_SWITCH(!params.is_bf16, [&] {
        //     HEADDIM_SWITCH(params.d, [&] {
        //         run_mha_bwd_<elem_type, kHeadDim>(params, stream);
        //     });
        // });
    ARCH_SWITCH(params.arch, Arch, [&] {
        SOFTCAP_SWITCH(params.softcap > 0.f, Has_softcap, [&] {
            if (!params.is_bf16) {
                #ifndef FLASHATTENTION_DISABLE_FP16
                #ifndef FLASHATTENTION_DISABLE_HDIM64
                if (params.d <= 64) { return run_mha_bwd_<Arch, cutlass::half_t, 64, Has_softcap>(params, stream); }
                #endif
                #ifndef FLASHATTENTION_DISABLE_HDIM96
                if (params.d <= 96) { return run_mha_bwd_<Arch, cutlass::half_t, 96, Has_softcap>(params, stream); }
                #endif
                #ifndef FLASHATTENTION_DISABLE_HDIM128
                if (params.d <= 128) { return run_mha_bwd_<Arch, cutlass::half_t, 128, Has_softcap>(params, stream); }
                #endif
                #ifndef FLASHATTENTION_DISABLE_HDIM192
                if (params.d <= 192) { return run_mha_bwd_<Arch, cutlass::half_t, 192, Has_softcap>(params, stream); }
                #endif
                #ifndef FLASHATTENTION_DISABLE_HDIM256
                if (params.d <= 256) { return run_mha_bwd_<Arch, cutlass::half_t, 256, Has_softcap>(params, stream); }
                #endif
                #else
                TORCH_CHECK(false, "This flash attention build does not support FP16.");
                #endif
            } else {
                #ifndef FLASHATTENTION_DISABLE_HDIM64
                if (params.d <= 64) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 64, Has_softcap>(params, stream); }
                #endif
                #ifndef FLASHATTENTION_DISABLE_HDIM96
                if (params.d <= 96) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 96, Has_softcap>(params, stream); }
                #endif
                #ifndef FLASHATTENTION_DISABLE_HDIM128
                if (params.d <= 128) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 128, Has_softcap>(params, stream); }
                #endif
                #ifndef FLASHATTENTION_DISABLE_HDIM192
                if (params.d <= 192) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 192, Has_softcap>(params, stream); }
                #endif
                #ifndef FLASHATTENTION_DISABLE_HDIM256
                if (params.d <= 256) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 256, Has_softcap>(params, stream); }
                #endif
            }
        });
    });
    #endif
}
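
// A hypothetical helper (an editor's sketch, not part of this file) that
// mirrors the `params.d <= N` cascade in run_mha_bwd above: an arbitrary head
// dimension is served by the smallest compiled head-dim bucket that covers it.
static int pick_headdim_bucket_sketch(int d) {
    for (int bucket : {64, 96, 128, 192, 256}) {
        if (d <= bucket) { return bucket; }
    }
    return -1;  // d > 256: callers reject this earlier via get_max_headdim()
}
// e.g. d = 80 runs the 96 bucket, d = 128 the 128 bucket.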

// b: batch_size
// s_q: seqlen_q
// s_k: seqlen_k
// h: num_heads
// h_k: num_heads_k
// d: head_size
std::vector<at::Tensor> mha_bwd(
    const at::Tensor &dout,  // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
    const at::Tensor &q,     // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
    const at::Tensor &k,     // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
    const at::Tensor &v,     // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
    const at::Tensor &out,   // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
    const at::Tensor &softmax_lse,    // (b, h, s_q) or (h, total_q) if there is cu_seqlens_q
    std::optional<at::Tensor> &dq_,   // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
    std::optional<at::Tensor> &dk_,   // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
    std::optional<at::Tensor> &dv_,   // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
    std::optional<const at::Tensor> &cu_seqlens_q_,   // b+1
    std::optional<const at::Tensor> &cu_seqlens_k_,   // b+1
    std::optional<const at::Tensor> &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
    std::optional<const at::Tensor> &seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
    std::optional<int> max_seqlen_q_,
    std::optional<int> max_seqlen_k_,
    float const softmax_scale,
    bool is_causal,
    int window_size_left,
    int window_size_right,
    float const softcap,
    bool const deterministic,
    int const sm_margin) {

    #ifdef FLASHATTENTION_DISABLE_BACKWARD
        TORCH_CHECK(false, "This flash attention build does not support backward.");
    #endif

    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm8x = dprops->major >= 8;
    TORCH_CHECK(is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");

    auto q_type = q.dtype();
    TORCH_CHECK(q_type == torch::kFloat16 || q_type == torch::kBFloat16,
                "FlashAttention only supports fp16 and bf16 data types");
    TORCH_CHECK(k.dtype() == q_type, "query and key must have the same dtype");
    TORCH_CHECK(v.dtype() == q_type, "query and value must have the same dtype");
    TORCH_CHECK(out.dtype() == q_type, "query and out must have the same dtype");
    TORCH_CHECK(dout.dtype() == q_type, "query and dout must have the same dtype");

    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
    CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);

    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
    TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");

    at::Tensor cu_seqlens_q;
    bool const is_varlen_q = cu_seqlens_q_.has_value();
    if (is_varlen_q) {
        cu_seqlens_q = cu_seqlens_q_.value();
        CHECK_DEVICE(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_q);
        TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32");
        TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided if cu_seqlens_q is provided");
    }
    at::Tensor cu_seqlens_k;
    bool const is_varlen_k = cu_seqlens_k_.has_value();
    if (is_varlen_k) {
        cu_seqlens_k = cu_seqlens_k_.value();
        CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k);
        TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32");
        TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided if cu_seqlens_k is provided");
    }
    // This is what we will template on
    bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value();
    #ifdef FLASHATTENTION_DISABLE_VARLEN
        TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen.");
    #endif

    auto const sizes = q.sizes();
    int const batch_size = !is_varlen_q ? sizes[0] : cu_seqlens_q.size(0) - 1;
    int const seqlen_q = !is_varlen_q ? sizes[1] : max_seqlen_q_.value();
    int const total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0];
    int const num_heads = q.size(-2);
    int const head_size = q.size(-1);
    int const seqlen_k = !is_varlen_k ? k.size(1) : max_seqlen_k_.value();
    int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0);
    int const num_heads_k = k.size(-2);
    TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
    int const max_headdim = get_max_headdim();
    TORCH_CHECK(head_size <= max_headdim, "FlashAttention backward only supports head dimension at most " + std::to_string(max_headdim));
    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM
    if (window_size_left >= seqlen_k - 1) { window_size_left = -1; }
    if (window_size_right >= seqlen_q - 1) { window_size_right = -1; }
    if (is_causal) { window_size_right = 0; }
    // There's a case where is_causal=false, window_size=(-1, 0). Then set_params_bprop will set params.is_causal=true.
    // If we don't have is_causal here matching params.is_causal, we might get the wrong kBlockM (and cause IMA).
    is_causal = window_size_left < 0 && window_size_right == 0;
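    // (Example: is_causal=true with window_size=(-1, -1) becomes
    // window_size=(-1, 0) above and stays is_causal=true here, while a sliding
    // window such as (128, 0) yields is_causal=false and is handled as local
    // attention below.)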

    int const arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor;
    int const head_size_rounded = round_up_headdim(head_size);
    // Very important that these match the kernel configs
    bool const is_local = (window_size_left >= 0 || window_size_right >= 0) && !is_causal;
    int const kBlockM_sm90 = head_size_rounded <= 64 ? (is_causal && softcap > 0.0 ? 96 : 128)
        : (head_size_rounded <= 96 ? 64
           : (head_size_rounded <= 128 ? (is_causal || is_local || softcap > 0.0 ? 64 : 80)
              : 64));
    int const kBlockM_sm80 = head_size_rounded <= 64 ? 128 : 64;
    int const kBlockM_sm86 = head_size_rounded <= 192 ? 64 : 32;
    int const kBlockM = arch >= 90 ? kBlockM_sm90 : (arch == 86 || arch == 89 ? kBlockM_sm86 : kBlockM_sm80);
    int const kBlockN_sm90 = head_size_rounded <= 128
        ? 128
        : (head_size_rounded <= 192 ? 96 : 80);
    int const kBlockN_sm80 = head_size_rounded <= 128
        ? 128
        : (head_size_rounded <= 192 ? 80 : 64);
    int const kBlockN_sm86 = head_size_rounded <= 64 ? 128
        : (head_size_rounded <= 96 ? 128
           : (head_size_rounded <= 128 ? 96 : 64));
    int const kBlockN = arch >= 90 ? kBlockN_sm90 : (arch == 86 || arch == 89 ? kBlockN_sm86 : kBlockN_sm80);
    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    int const seqlen_q_rounded = round_multiple(seqlen_q, kBlockM);
    int const seqlen_k_rounded = round_multiple(seqlen_k, kBlockN);
    int const total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM);
    int const total_k_padded_rounded = round_multiple(total_k + batch_size * kBlockN, kBlockN);
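    // (Worked example with kBlockM = 128: seqlen_q = 1000 gives
    // seqlen_q_rounded = 1024, and batch_size = 2 with total_q = 1000 gives
    // total_q_padded_rounded = round_multiple(1000 + 2 * 128, 128) = 1280.)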

    if (!is_varlen_q) {
        CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
        CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
        CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size);
    } else {
        CHECK_SHAPE(q, total_q, num_heads, head_size);
        CHECK_SHAPE(out, total_q, num_heads, head_size);
        CHECK_SHAPE(dout, total_q, num_heads, head_size);
        CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
    }
    if (!is_varlen_k) {
        CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
        CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
    } else {
        CHECK_SHAPE(k, total_k, num_heads_k, head_size);
        CHECK_SHAPE(v, total_k, num_heads_k, head_size);
        CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
    }

    if (seqused_q_.has_value()) {
        auto seqused_q = seqused_q_.value();
        TORCH_CHECK(seqused_q.dtype() == torch::kInt32, "seqused_q must have dtype int32");
        CHECK_DEVICE(seqused_q); CHECK_CONTIGUOUS(seqused_q);
        CHECK_SHAPE(seqused_q, batch_size);
    }
    if (seqused_k_.has_value()) {
        auto seqused_k = seqused_k_.value();
        TORCH_CHECK(seqused_k.dtype() == torch::kInt32, "seqused_k must have dtype int32");
        CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k);
        CHECK_SHAPE(seqused_k, batch_size);
    }

    at::Tensor dq, dk, dv;
    if (dq_.has_value()) {
        dq = dq_.value();
        TORCH_CHECK(dq.dtype() == q_type, "dq must have the same dtype as q");
        CHECK_DEVICE(dq);
        TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
        if (!is_varlen_q) {
            CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
        } else {
            CHECK_SHAPE(dq, total_q, num_heads, head_size);
        }
    } else {
        dq = torch::empty_like(q);
    }
    if (dk_.has_value()) {
        dk = dk_.value();
        TORCH_CHECK(dk.dtype() == q_type, "dk must have the same dtype as q");
        CHECK_DEVICE(dk);
        TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
        if (!is_varlen_k) {
            CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
        } else {
            CHECK_SHAPE(dk, total_k, num_heads_k, head_size);
        }
    } else {
        dk = torch::empty_like(k);
    }
    if (dv_.has_value()) {
        dv = dv_.value();
        TORCH_CHECK(dv.dtype() == q_type, "dv must have the same dtype as q");
        CHECK_DEVICE(dv);
        TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
        if (!is_varlen_k) {
            CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
        } else {
            CHECK_SHAPE(dv, total_k, num_heads_k, head_size);
        }
    } else {
        dv = torch::empty_like(v);
    }

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)q.get_device()};

    auto opts = q.options();
    // Need softmax_d to have total_q_padded_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64
    at::Tensor softmax_d, softmax_lse_log2;
    if (!is_varlen) {
        // Need softmax_d to have seqlen_q_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64
        softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
        softmax_lse_log2 = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
    } else {
        softmax_d = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat));
        softmax_lse_log2 = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat));
    }
    at::Tensor dq_accum, dk_accum, dv_accum;
    if (!is_varlen) {
        dq_accum = torch::empty({batch_size, num_heads, seqlen_q_rounded * head_size_rounded}, opts.dtype(at::kFloat));
    } else {
        dq_accum = torch::empty({num_heads, total_q_padded_rounded * head_size_rounded}, opts.dtype(at::kFloat));
    }
    if (num_heads_k != num_heads) {  // MQA / GQA
        if (!is_varlen) {
            dk_accum = torch::zeros({batch_size, num_heads_k, seqlen_k_rounded * head_size_rounded}, opts.dtype(at::kFloat));
            dv_accum = torch::zeros({batch_size, num_heads_k, seqlen_k_rounded * head_size_rounded}, opts.dtype(at::kFloat));
        } else {
            dk_accum = torch::zeros({num_heads_k, total_k_padded_rounded, head_size_rounded}, opts.dtype(at::kFloat));
            dv_accum = torch::zeros({num_heads_k, total_k_padded_rounded, head_size_rounded}, opts.dtype(at::kFloat));
        }
    }

    Flash_bwd_params params;
    set_params_dgrad(params,
                     batch_size,
                     seqlen_q, seqlen_k,
                     seqlen_q_rounded, seqlen_k_rounded,
                     num_heads, num_heads_k,
                     head_size, head_size_rounded,
                     q, k, v, out,
                     dout, dq, dk, dv,
                     !is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(),
                     !is_varlen_k ? nullptr : cu_seqlens_k.data_ptr(),
                     seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr,
                     seqused_k_.has_value() ? seqused_k_.value().data_ptr() : nullptr,
                     dq_accum.data_ptr(),
                     num_heads_k != num_heads ? dk_accum.data_ptr() : nullptr,
                     num_heads_k != num_heads ? dv_accum.data_ptr() : nullptr,
                     softmax_lse.data_ptr(),
                     softmax_d.data_ptr(),
                     /*p_dropout=*/0.f,
                     softmax_scale,
                     window_size_left,
                     window_size_right,
                     softcap,
                     deterministic,
                     sm_margin);
    params.total_q = total_q;
    params.total_k = total_k;
    params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr();
    params.dv = head_size;  // We don't support hdim_v being different from hdim_qk for now

    // auto tile_count_semaphore = (params.is_causal || params.is_local) ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32));
    // params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
    // Will be zero'ed out in the backward preprocess kernel
    at::Tensor dq_semaphore = torch::empty({(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, opts.dtype(torch::kInt32));
    params.dq_semaphore = dq_semaphore.data_ptr<int>();
    if (num_heads_k != num_heads && params.deterministic) {
        // TODO: do we need to zero them out?
        at::Tensor dk_semaphore = torch::empty({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32));
        at::Tensor dv_semaphore = torch::empty({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32));
        params.dk_semaphore = dk_semaphore.data_ptr<int>();
        params.dv_semaphore = dv_semaphore.data_ptr<int>();
    }

    #ifdef FLASHATTENTION_DISABLE_LOCAL
    TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
    #endif
    #ifdef FLASHATTENTION_DISABLE_SOFTCAP
    TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping.");
    #endif

    if (total_q > 0 && total_k > 0 && num_heads_k > 0) {
        auto stream = at::cuda::getCurrentCUDAStream().stream();
        run_mha_bwd(params, stream);
    } else if (total_k > 0 && num_heads_k > 0) {
        // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
        dk.zero_();
        dv.zero_();
        softmax_d.zero_();
    } else if (total_q > 0 && num_heads_k > 0) {
        dq.zero_();
        softmax_d.zero_();
    }

    return { dq, dk, dv, softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum };
}
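
// A minimal non-varlen usage sketch for mha_bwd (an editor's sketch with a
// hypothetical function name, not part of this file). All optionals are left
// empty so mha_bwd allocates dq/dk/dv itself; out and softmax_lse are random
// stand-ins here, whereas a real caller passes the tensors produced by the
// forward pass. The returned vector is
// {dq, dk, dv, softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum}.
static std::vector<at::Tensor> mha_bwd_usage_sketch() {
    int const b = 2, s = 128, h = 8, d = 64;
    auto opts = torch::dtype(torch::kBFloat16).device(torch::kCUDA);
    at::Tensor q    = torch::randn({b, s, h, d}, opts);
    at::Tensor k    = torch::randn({b, s, h, d}, opts);
    at::Tensor v    = torch::randn({b, s, h, d}, opts);
    at::Tensor out  = torch::randn({b, s, h, d}, opts);                    // stand-in for the fwd output
    at::Tensor lse  = torch::randn({b, h, s}, opts.dtype(torch::kFloat));  // stand-in for the fwd softmax_lse
    at::Tensor dout = torch::randn({b, s, h, d}, opts);
    std::optional<at::Tensor> no_grad;           // let mha_bwd allocate dq/dk/dv
    std::optional<const at::Tensor> no_seqlens;  // non-varlen path
    std::optional<int> no_maxlen;
    return mha_bwd(dout, q, k, v, out, lse,
                   no_grad, no_grad, no_grad,       // dq_, dk_, dv_
                   no_seqlens, no_seqlens,          // cu_seqlens_q_, cu_seqlens_k_
                   no_seqlens, no_seqlens,          // seqused_q_, seqused_k_
                   no_maxlen, no_maxlen,            // max_seqlen_q_, max_seqlen_k_
                   /*softmax_scale=*/0.125f,        // 1/sqrt(d) for d = 64
                   /*is_causal=*/true,
                   /*window_size_left=*/-1, /*window_size_right=*/-1,
                   /*softcap=*/0.f, /*deterministic=*/false, /*sm_margin=*/0);
}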

std::vector<at::Tensor>
mha_combine(const at::Tensor &out_partial,         // num_splits x batch_size x seqlen x num_heads x head_size
            const at::Tensor &lse_partial,         // num_splits x batch_size x seqlen x num_heads
            std::optional<at::Tensor> out_,        // batch_size x seqlen x num_heads x head_size
            std::optional<at::ScalarType> out_dtype_
            ) {

    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm8x = dprops->major >= 8;
    TORCH_CHECK(is_sm8x, "Attention combine function only supports Ampere GPUs or newer.");

    auto out_partial_type = out_partial.scalar_type();
    TORCH_CHECK(out_partial_type == at::ScalarType::Float, "Attention combine function only supports fp32 data type");
    TORCH_CHECK(lse_partial.scalar_type() == at::ScalarType::Float, "Attention combine function only supports fp32 data type");

    CHECK_DEVICE(out_partial); CHECK_DEVICE(lse_partial);

    TORCH_CHECK(out_partial.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(lse_partial.stride(-2) == 1, "LSE tensor must be contiguous in the seqlen dimension");

    const auto sizes = out_partial.sizes();

    const int num_splits = sizes[0];
    const int batch_size = sizes[1];
    const int seqlen = sizes[2];
    const int num_heads = sizes[3];
    const int head_size_og = sizes[4];
    TORCH_CHECK(num_splits <= 256, "FlashAttention combine only supports num_splits at most 256");

    CHECK_SHAPE(out_partial, num_splits, batch_size, seqlen, num_heads, head_size_og);
    CHECK_SHAPE(lse_partial, num_splits, batch_size, seqlen, num_heads);

    int const alignment = 4;
    at::Tensor out_partial_padded;
    auto pad = [](at::Tensor x, int alignment) {
        return x.size(-1) % alignment == 0 ? x : torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, alignment - x.size(-1) % alignment}));
    };
    out_partial_padded = pad(out_partial, alignment);

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    const int head_size = round_multiple(head_size_og, alignment);

    auto opts = out_partial.options();
    at::ScalarType out_type = out_dtype_.value_or(out_partial.scalar_type());
    TORCH_CHECK(out_type == at::ScalarType::Float || out_type == at::ScalarType::BFloat16 || out_type == at::ScalarType::Half, "Output type must be FP32, FP16 or BF16");
    at::Tensor out;
    if (out_.has_value()) {
        out = out_.value();
        TORCH_CHECK(out.scalar_type() == out_type);
        CHECK_DEVICE(out);
        TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
        CHECK_SHAPE(out, batch_size, seqlen, num_heads, head_size_og);
        if (head_size_og % alignment != 0) {
            out = torch::empty({batch_size, seqlen, num_heads, head_size}, opts.dtype(out_type));
        }
    } else {
        out = torch::empty({batch_size, seqlen, num_heads, head_size}, opts.dtype(out_type));
    }

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)out_partial.get_device()};

    auto softmax_lse = torch::empty({batch_size, num_heads, seqlen}, opts.dtype(at::kFloat)).transpose(1, 2);

    Flash_fwd_params params {};  // Need to reset the params to set everything to zero
    params.is_fp32 = out_type == at::ScalarType::Float;
    params.is_bf16 = out_type == at::ScalarType::BFloat16;
    params.oaccum_ptr = out_partial_padded.data_ptr();
    params.softmax_lseaccum_ptr = lse_partial.data_ptr();
    params.o_ptr = out.data_ptr();
    params.softmax_lse_ptr = softmax_lse.data_ptr();
    params.b = batch_size;
    params.h = num_heads;
    params.seqlen_q = seqlen;
    params.dv = head_size;
    params.num_splits = num_splits;
    params.oaccum_split_stride = out_partial_padded.stride(0);
    params.oaccum_row_stride = out_partial_padded.stride(2);
    params.oaccum_head_stride = out_partial_padded.stride(3);
    params.oaccum_batch_stride = out_partial_padded.stride(1);
    params.lseaccum_split_stride = lse_partial.stride(0);
    params.lseaccum_head_stride = lse_partial.stride(3);
    params.lseaccum_batch_stride = lse_partial.stride(1);
    params.o_row_stride = out.stride(1);
    params.o_head_stride = out.stride(2);
    params.o_batch_stride = out.stride(0);
    params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor;

    if (seqlen > 0 && batch_size > 0) {
        auto stream = at::cuda::getCurrentCUDAStream().stream();
        run_mha_fwd_combine(params, stream, false /*enable_pdl*/);
    }

    at::Tensor out_padded = out;
    if (head_size_og % alignment != 0) {
        out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
        // if (out_.has_value()) { out_.value().copy_(out); }
    }

    return {out, softmax_lse};
}
         | 
| 1610 | 
            -
             | 
| 1611 | 
            -
            #ifndef FLASHATTENTION_DISABLE_PYBIND
         | 
| 1612 | 
            -
             | 
| 1613 | 
            -
            #include <torch/python.h>
         | 
| 1614 | 
            -
             | 
| 1615 | 
            -
            PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         | 
| 1616 | 
            -
                m.doc() = "FlashAttention";
         | 
| 1617 | 
            -
                m.def("fwd", &mha_fwd, "Forward pass");
         | 
| 1618 | 
            -
                m.def("bwd", &mha_bwd, "Backward pass");
         | 
| 1619 | 
            -
                m.def("fwd_combine", &mha_combine, "Combine partial attention outputs");
         | 
| 1620 | 
            -
                m.def("get_scheduler_metadata", &mha_fwd_get_scheduler_metadata, "Get scheduler metadata for varlen forward pass");
         | 
| 1621 | 
            -
            }
         | 
| 1622 | 
            -
             | 
| 1623 | 
            -
            #endif
         | 
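For orientation, the combine step that `run_mha_fwd_combine` launches merges the per-split partial outputs produced by split-KV attention, weighting each split by its log-sum-exp (LSE). The following is a minimal scalar sketch of that reduction for a single output row; `combine_splits` is a hypothetical illustration of the math, not the deleted CUDA kernel:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Hypothetical scalar model of the split-KV combine reduction for one row.
// out_partial[s] holds split s's partial output (length head_size);
// lse_partial[s] is that split's log-sum-exp over its KV chunk.
std::vector<float> combine_splits(const std::vector<std::vector<float>>& out_partial,
                                  const std::vector<float>& lse_partial,
                                  int head_size) {
    const int num_splits = static_cast<int>(lse_partial.size());
    // Global LSE over all splits, stabilized with the running max.
    const float lse_max = *std::max_element(lse_partial.begin(), lse_partial.end());
    float sum = 0.f;
    for (int s = 0; s < num_splits; ++s) sum += std::exp(lse_partial[s] - lse_max);
    const float lse_final = lse_max + std::log(sum);
    // Rescale each split's output by exp(lse_s - lse_final) and accumulate.
    std::vector<float> out(head_size, 0.f);
    for (int s = 0; s < num_splits; ++s) {
        const float scale = std::exp(lse_partial[s] - lse_final);
        for (int d = 0; d < head_size; ++d) out[d] += scale * out_partial[s][d];
    }
    return out;
}
```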
flash-attn/flash_bwd_kernel_sm80.h DELETED
@@ -1,173 +0,0 @@
```cpp
/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cute/tensor.hpp"

#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/kernel_hardware_info.h>

#include "utils.h"

namespace flash {

using namespace cute;

template <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
class FlashAttnBwdSm80 {

public:

    // Type Aliases
    static constexpr bool Is_causal = CollectiveMainloop_::Is_causal;
    static constexpr bool Is_local = CollectiveMainloop_::Is_local;
    static_assert(CollectiveMainloop_::Varlen == CollectiveEpilogue_::Varlen);
    static constexpr bool Varlen = CollectiveMainloop_::Varlen;

    // Mainloop derived types
    using CollectiveMainloop = CollectiveMainloop_;
    using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK;
    using TiledMmaSdP = typename CollectiveMainloop::TiledMmaSdP;
    using TiledMmadKV = typename CollectiveMainloop::TiledMmadKV;
    using ArchTag = typename CollectiveMainloop::ArchTag;
    using MainloopArguments = typename CollectiveMainloop::Arguments;
    using MainloopParams = typename CollectiveMainloop::Params;
    static constexpr bool dKV_swapAB = CollectiveMainloop::dKV_swapAB;

    // Epilogue derived types
    using CollectiveEpilogue = CollectiveEpilogue_;
    using EpilogueArguments = typename CollectiveEpilogue::Arguments;
    using EpilogueParams = typename CollectiveEpilogue::Params;

    static_assert(ArchTag::kMinComputeCapability >= 80);

    using TileScheduler = TileScheduler_;
    using TileSchedulerArguments = typename flash::TileSchedulerArguments;
    using TileSchedulerParams = typename TileScheduler::Params;

    static constexpr uint32_t NumThreads = CUTE_STATIC_V(size(TiledMmaSdP{}));
    static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaSdP{}));
    static constexpr uint32_t MinBlocksPerMultiprocessor = 1;

    // Kernel level shared memory storage
    struct SharedStorage {
        struct TensorStorage : cute::aligned_struct<128> {
            union {
                typename CollectiveMainloop::TensorStorage mainloop;
                typename CollectiveEpilogue::TensorStorage epilogue;
            };
        } tensors;

        alignas(16) typename TileScheduler::SharedStorage smem_scheduler;

    };

    static constexpr int SharedStorageSize = sizeof(SharedStorage);

    // Device side arguments
    struct Arguments {
        MainloopArguments mainloop{};
        EpilogueArguments epilogue{};
        cutlass::KernelHardwareInfo hw_info{};
        TileSchedulerArguments scheduler{};
    };

    // Kernel entry point API
    struct Params {
        MainloopParams mainloop{};
        EpilogueParams epilogue{};
        cutlass::KernelHardwareInfo hw_info{};
        TileSchedulerParams scheduler{};
    };

    //
    // Methods
    //

    // Convert to underlying arguments. In this case, a simple copy for the aliased type.
    static
    Params
    to_underlying_arguments(Arguments const& args) {
        CUTLASS_TRACE_HOST("to_underlying_arguments():");

        // Get SM count if needed, otherwise use user supplied SM count
        int sm_count = args.hw_info.sm_count;
        if (sm_count <= 0) {
            CUTLASS_TRACE_HOST("  WARNING: Arguments do not include a valid SM count.\n"
                "  For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
            sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
        }

        CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);

        cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
        return {
            CollectiveMainloop::to_underlying_arguments(args.mainloop),
            CollectiveEpilogue::to_underlying_arguments(args.epilogue),
            hw_info,
            TileScheduler::to_underlying_arguments(args.scheduler)
        };
    }

    // Computes the kernel launch grid shape based on runtime parameters
    static dim3
    get_grid_shape(Params const& params) {
        return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count);
    }

    static dim3
    get_block_shape() {
        return dim3(MaxThreadsPerBlock, 1, 1);
    }

    CUTLASS_DEVICE
    void
    operator()(Params const& params, char* smem_buf) {

        static constexpr int kBlockM = get<0>(TileShape_MNK{});
        static constexpr int kBlockN = get<1>(TileShape_MNK{});

        SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);

        CollectiveMainloop mainloop;
        CollectiveEpilogue epilogue;

        TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.smem_scheduler));
        // Initialize matmul objects.
        TiledMmadKV tiled_mma_dKV;

        scheduler.init_consumer();

        int warp_idx = cutlass::canonical_warp_idx_sync();
        CUTLASS_PRAGMA_NO_UNROLL
        for (auto work_tile_info = warp_idx == 0 ? scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler) : scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
             work_tile_info.is_valid(params.scheduler);
             work_tile_info = warp_idx == 0 ? scheduler.template get_next_work</*IsProducerWarp=*/true>(params.scheduler, work_tile_info) : scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {

            auto block_coord_ = work_tile_info.get_block_coord(params.scheduler);
            auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_;
            cute::tuple<int32_t, int32_t, int32_t> block_coord = {n_block, bidh, bidb};

            // dK and dV output accumulator.
            Tensor tdKrdK = partition_fragment_C(tiled_mma_dKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));
            Tensor tdVrdV = partition_fragment_C(tiled_mma_dKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));
            bool tile_valid = mainloop.mma(params.mainloop, tdKrdK, tdVrdV, threadIdx.x,
                                           block_coord, shared_storage);
            scheduler.prefetch_next_work(params.scheduler, work_tile_info);
            if (tile_valid) {
                epilogue.store(params.epilogue, tdKrdK, tdVrdV, shared_storage, tiled_mma_dKV,
                               threadIdx.x, block_coord);
            } else {
                epilogue.store_zero(params.epilogue, threadIdx.x, block_coord);
            }
        }

    }

};

} // namespace flash
```
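The structure worth noting in this kernel is the persistent work loop: each thread block stays resident and keeps pulling `(n_block, bidh, bidb)` tiles from the `TileScheduler` until none remain, instead of launching one block per tile. A minimal CUDA sketch of that pattern, with hypothetical simplified types (`WorkTile` and `StaticScheduler` are illustrative stand-ins for the templated scheduler above):

```cpp
// Hypothetical, simplified illustration of the persistent-scheduler pattern:
// blocks grid-stride through the flattened tile space until it is exhausted.
struct WorkTile { int tile_idx; bool valid; };

struct StaticScheduler {
    int num_tiles;
    __device__ WorkTile initial_work() const {
        int first = static_cast<int>(blockIdx.x);
        return {first, first < num_tiles};
    }
    __device__ WorkTile next_work(WorkTile cur) const {
        int next = cur.tile_idx + static_cast<int>(gridDim.x);
        return {next, next < num_tiles};
    }
};

template <class ComputeFn>
__device__ void persistent_loop(const StaticScheduler& scheduler, ComputeFn compute_tile) {
    for (WorkTile tile = scheduler.initial_work();
         tile.valid;
         tile = scheduler.next_work(tile)) {
        // The real kernel computes dK/dV for this tile here, and prefetches
        // the next tile's metadata before running the epilogue.
        compute_tile(tile.tile_idx);
    }
}
```

The deleted scheduler is more elaborate (producer/consumer template variants, shared-memory state, prefetch), but the loop shape is the same.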
flash-attn/flash_bwd_kernel_sm90.h DELETED
@@ -1,282 +0,0 @@
```cpp

/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cute/tensor.hpp"

#include <cutlass/cutlass.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/kernel_hardware_info.h>
#include "cutlass/pipeline/pipeline.hpp"

#include "utils.h"

namespace flash {

using namespace cute;

template <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
class FlashAttnBwdSm90 {

public:

    // Type Aliases
    static constexpr bool Is_causal = CollectiveMainloop_::Is_causal;
    static constexpr bool Is_local = CollectiveMainloop_::Is_local;
    static_assert(CollectiveMainloop_::Varlen == CollectiveEpilogue_::Varlen);
    static constexpr bool Varlen = CollectiveMainloop_::Varlen;

    // Mainloop derived types
    using CollectiveMainloop = CollectiveMainloop_;
    using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK;
    using TiledMmaSdP = typename CollectiveMainloop::TiledMmaSdP;
    using TiledMmadKV = typename CollectiveMainloop::TiledMmadKV;
    using ArchTag = typename CollectiveMainloop::ArchTag;
    using ClusterShape = typename CollectiveMainloop::ClusterShape;
    using MainloopArguments = typename CollectiveMainloop::Arguments;
    using MainloopParams = typename CollectiveMainloop::Params;
    static constexpr bool dKV_swapAB = CollectiveMainloop::dKV_swapAB;

    // Epilogue derived types
    using CollectiveEpilogue = CollectiveEpilogue_;
    using EpilogueArguments = typename CollectiveEpilogue::Arguments;
    using EpilogueParams = typename CollectiveEpilogue::Params;

    static_assert(ArchTag::kMinComputeCapability >= 90);

    using TileScheduler = TileScheduler_;
    using TileSchedulerArguments = typename flash::TileSchedulerArguments;
    using TileSchedulerParams = typename TileScheduler::Params;

    static constexpr uint32_t NumLoadWarpGroups = 1;
    static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMmaSdP{})) / cutlass::NumThreadsPerWarpGroup;
    static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaSdP{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup);
    static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
    static_assert(NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3);

    /// Register requirement for Load and Math WGs
    static constexpr uint32_t LoadRegisterRequirement = NumMmaWarpGroups == 2 ? 24 : 32;
    static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 240 : 160;
    // If you want to print from the producer warp, you'd need to increase the number of registers
    // Otherwise you'll get CUDA error.
    // static constexpr uint32_t LoadRegisterRequirement = 40;
    // static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 232 : 152;

    // Kernel level shared memory storage
    struct SharedStorage {
        struct TensorStorage : cute::aligned_struct<128> {
            union {
                typename CollectiveMainloop::TensorStorage mainloop;
                typename CollectiveEpilogue::TensorStorage epilogue;
            };
        } tensors;

        struct PipelineStorage : cute::aligned_struct<16> {
            alignas(16) cutlass::arch::ClusterTransactionBarrier barrier_KV;
            alignas(16) typename CollectiveMainloop::MainloopPipeline::SharedStorage pipeline_q;
            alignas(16) typename CollectiveMainloop::MainloopPipeline_dO::SharedStorage pipeline_do;
            alignas(16) typename TileScheduler::SharedStorage smem_scheduler;
        } pipelines;

    };

    static constexpr int SharedStorageSize = sizeof(SharedStorage);

    // Device side arguments
    struct Arguments {
        MainloopArguments mainloop{};
        EpilogueArguments epilogue{};
        cutlass::KernelHardwareInfo hw_info{};
        TileSchedulerArguments scheduler{};
    };

    // Kernel entry point API
    struct Params {
        MainloopParams mainloop{};
        EpilogueParams epilogue{};
        cutlass::KernelHardwareInfo hw_info{};
        TileSchedulerParams scheduler{};
    };

    //
    // Methods
    //

    // Convert to underlying arguments. In this case, a simple copy for the aliased type.
    static
    Params
    to_underlying_arguments(Arguments const& args) {
        CUTLASS_TRACE_HOST("to_underlying_arguments():");

        // Get SM count if needed, otherwise use user supplied SM count
        int sm_count = args.hw_info.sm_count;
        if (sm_count <= 0) {
            CUTLASS_TRACE_HOST("  WARNING: Arguments do not include a valid SM count.\n"
                "  For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
            sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
        }

        CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);

        cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
        return {
            CollectiveMainloop::to_underlying_arguments(args.mainloop),
            CollectiveEpilogue::to_underlying_arguments(args.epilogue),
            hw_info,
            TileScheduler::to_underlying_arguments(args.scheduler)
        };
    }

    // Computes the kernel launch grid shape based on runtime parameters
    static dim3
    get_grid_shape(Params const& params) {
        return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count);
    }

    static dim3
    get_block_shape() {
        return dim3(MaxThreadsPerBlock, 1, 1);
    }

    CUTLASS_DEVICE
    void
    operator()(Params const& params, char* smem_buf) {

        static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup;
        static constexpr int NumCopyThreads = NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup;
        static constexpr int kBlockM = get<0>(TileShape_MNK{});
        static constexpr int kBlockN = get<1>(TileShape_MNK{});

        using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
        using PipelineParams = typename MainloopPipeline::Params;
        using PipelineState = typename MainloopPipeline::PipelineState;
        using MainloopPipeline_dO = typename CollectiveMainloop::MainloopPipeline_dO;
        using PipelineParams_dO = typename MainloopPipeline_dO::Params;
        using PipelineState_dO = typename MainloopPipeline_dO::PipelineState;
        static constexpr bool Q_dO_same_stages = std::is_same_v<MainloopPipeline, MainloopPipeline_dO>;

        SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);

        int const lane_predicate = cute::elect_one_sync();
        int const warp_idx = cutlass::canonical_warp_idx_sync();

        // Issue Tma Descriptor Prefetch from a single thread
        if (warp_idx == 0 && lane_predicate) {
            CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
            CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
        }

        // Obtain warp index
        int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;

        PipelineParams pipeline_params;
        pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesQ + CollectiveMainloop::TmaTransactionBytesLSE;
        int warp_group_idx = cutlass::canonical_warp_group_idx();
        pipeline_params.role = warp_group_idx == 0
            ? MainloopPipeline::ThreadCategory::Producer
            : MainloopPipeline::ThreadCategory::Consumer;
        pipeline_params.is_leader = warp_group_thread_idx == 0;
        pipeline_params.num_consumers = NumMmaThreads;

        if (warp_idx == 0 && lane_predicate) {
            shared_storage.pipelines.barrier_KV.init(1 /*numThreads*/);
        }
        // We're counting on pipeline_q to call cutlass::arch::fence_barrier_init();
        MainloopPipeline pipeline_q(shared_storage.pipelines.pipeline_q, pipeline_params, ClusterShape{});
        auto role_dO = warp_group_idx == 0
            ? MainloopPipeline_dO::ThreadCategory::Producer
            : MainloopPipeline_dO::ThreadCategory::Consumer;
        PipelineParams_dO pipeline_params_dO {pipeline_params.transaction_bytes, role_dO, pipeline_params.is_leader, pipeline_params.num_consumers};
        MainloopPipeline_dO pipeline_do(shared_storage.pipelines.pipeline_do, cute::conditional_return<Q_dO_same_stages>(pipeline_params, pipeline_params_dO), ClusterShape{});

        CollectiveMainloop mainloop;
        CollectiveEpilogue epilogue;

        // We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster
        if constexpr (size(ClusterShape{}) > 1) {
            cute::cluster_arrive_relaxed();
            cute::cluster_wait();
        } else {
            __syncthreads();
        }

        TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.pipelines.smem_scheduler));

        if (warp_group_idx == 0) {  // Producer
            cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();

            int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
            if (warp_idx_in_warpgroup == 0) {  // Load K, V, and do TMA on Q and dO
                PipelineState smem_pipe_write = cutlass::make_producer_start_state<MainloopPipeline>();
                PipelineState_dO smem_pipe_write_do = cutlass::make_producer_start_state<MainloopPipeline_dO>();
                for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler);
                     work_tile_info.is_valid(params.scheduler);
                     work_tile_info = scheduler.template get_next_work</*IsProducerWarp=*/true>(params.scheduler, work_tile_info)) {
                    auto block_coord_ = work_tile_info.get_block_coord(params.scheduler);
                    auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_;
                    cute::tuple<int32_t, int32_t, int32_t> block_coord = {n_block, bidh, bidb};
                    auto scheduler_prefetch = [&scheduler, &params, &work_tile_info]() {
                        scheduler.prefetch_next_work(params.scheduler, work_tile_info);
                    };
                    mainloop.load(params.mainloop, pipeline_q, pipeline_do, smem_pipe_write,
                                  smem_pipe_write_do, shared_storage, scheduler_prefetch, block_coord);
                }
                mainloop.load_tail(pipeline_q, pipeline_do, smem_pipe_write, smem_pipe_write_do);
            } else if (warp_idx_in_warpgroup == 1) {
                for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
                     work_tile_info.is_valid(params.scheduler);
                     work_tile_info = scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {
                    auto block_coord_ = work_tile_info.get_block_coord(params.scheduler);
                    auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_;
                    cute::tuple<int32_t, int32_t, int32_t> block_coord = {n_block, bidh, bidb};
                    mainloop.store_dq(params.mainloop, shared_storage, block_coord);
                }
            }
        } else {  // Consumer
            cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
            // Initialize matmul objects.
            TiledMmadKV tiled_mma_dKV;

            PipelineState smem_pipe_read;
            PipelineState_dO smem_pipe_read_do;

            mainloop.mma_init();
            scheduler.init_consumer();

            int work_idx = 0;
            CUTLASS_PRAGMA_NO_UNROLL
            for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
                 work_tile_info.is_valid(params.scheduler);
                 work_tile_info = scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {
                auto block_coord_ = work_tile_info.get_block_coord(params.scheduler);
                auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_;
                cute::tuple<int32_t, int32_t, int32_t> block_coord = {n_block, bidh, bidb};

                // dK and dV output accumulator.
                Tensor tdKrdK = partition_fragment_C(tiled_mma_dKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));
                Tensor tdVrdV = partition_fragment_C(tiled_mma_dKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));
                bool tile_valid = mainloop.mma(
                    params.mainloop, pipeline_q, pipeline_do, smem_pipe_read, smem_pipe_read_do,
                    tdKrdK, tdVrdV, threadIdx.x - NumCopyThreads, work_idx, block_coord, shared_storage);
                if (tile_valid) {
                    epilogue.store(params.epilogue, tdKrdK, tdVrdV, shared_storage, tiled_mma_dKV,
                                   threadIdx.x - NumCopyThreads, block_coord);
                } else {
                    epilogue.store_zero(params.epilogue, threadIdx.x - NumCopyThreads, block_coord);
                }

            }
            epilogue.store_tail();
        }

    }

};

} // namespace flash
```
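A plausible reading of the `LoadRegisterRequirement` / `MmaRegisterRequirement` pairs above (stated as an assumption, not something this diff spells out): an SM has 64K 32-bit registers and a warp group is 128 threads, so the per-thread register targets across all warp groups in the block must fit within 65536 / 128 = 512, and the reallocation targets are multiples of 8. Both configurations sit at or just under that budget:

```cpp
// Assumed rationale for the producer/consumer register split:
// 65536 registers per SM / 128 threads per warp group = 512 regs/thread,
// shared across the load warp group and the MMA warp groups.
constexpr int kRegsPerSM = 65536;
constexpr int kThreadsPerWarpGroup = 128;
constexpr int kPerThreadBudget = kRegsPerSM / kThreadsPerWarpGroup;  // 512

// 1 load WG + 2 MMA WGs: 24 + 2 * 240 = 504 <= 512
static_assert(24 + 2 * 240 <= kPerThreadBudget, "2 MMA warp groups fit");
// 1 load WG + 3 MMA WGs: 32 + 3 * 160 = 512 <= 512
static_assert(32 + 3 * 160 <= kPerThreadBudget, "3 MMA warp groups fit");
```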
flash-attn/flash_bwd_launch_template.h DELETED
@@ -1,377 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
- ******************************************************************************/
-
-#pragma once
-
-#include "cute/tensor.hpp"
-
-#include "cutlass/device_kernel.h"  // For device_kernel
-#include "cutlass/kernel_launch.h"  // For kernel_launch
-#include "cutlass/cluster_launch.hpp"  // For ClusterLauncher
-
-#include "static_switch.h"
-#include "flash.h"
-#include "flash_bwd_preprocess_kernel.h"
-#include "flash_bwd_postprocess_kernel.h"
-#include "tile_scheduler.hpp"
-#include "mainloop_bwd_sm90_tma_gmma_ws.hpp"
-#include "mainloop_bwd_sm80.hpp"
-#include "epilogue_bwd.hpp"
-#include "flash_bwd_kernel_sm90.h"
-#include "flash_bwd_kernel_sm80.h"
-
-using namespace cute;
-
-template <int Arch, int kHeadDim, int kBlockM, int kBlockN, typename Element,
-          bool Is_causal, bool Is_local, bool Has_softcap, bool Varlen, bool Deterministic, bool GQA,
-          int Stages_dO=2, int Stages_dS_or_QSm80=2,
-          bool SdP_swapAB=true, bool dKV_swapAB=false, bool dQ_swapAB=false,
-          int NumMmaWarpGroups=2, int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1,
-          bool V_in_regs=false>
-void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
-    static_assert(!(Is_causal && Is_local), "Is_causal and Is_local cannot be true at the same time.");
-    using ElementAccum = float;
-    using ArchTag = std::conditional_t<Arch >= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>;
-
-    int const total_q_padded_rounded = cute::round_up(params.total_q + params.b * kBlockM, kBlockM);
-    int const total_k_padded_rounded = cute::round_up(params.total_k + params.b * kBlockN, kBlockN);
-    bool const is_varlen_q = params.cu_seqlens_q;
-    bool const is_varlen_k = params.cu_seqlens_k;
-    int seqlen_q = !is_varlen_q ? params.seqlen_q : params.total_q;
-    int seqlen_k = !is_varlen_k ? params.seqlen_k : params.total_k;
-    int seqlen_q_rounded = !is_varlen_q ? params.seqlen_q_rounded : total_q_padded_rounded;
-    int seqlen_k_rounded = !is_varlen_k ? params.seqlen_k_rounded : total_k_padded_rounded;
-    int batch_q = !is_varlen_q ? params.b : 1;
-    int batch_k = !is_varlen_k ? params.b : 1;
-
-    using TileShape_MK = cute::Shape<Int<kBlockM>, Int<kHeadDim>>;
-    using PreprocessKernel = flash::FlashAttnBwdPreprocess<TileShape_MK, Element, ElementAccum, ArchTag, /*Clear_dQaccum=*/true, Varlen>;
-    typename PreprocessKernel::Arguments preprocess_args {
-        static_cast<Element const*>(params.o_ptr),
-        {seqlen_q, params.d, params.h, batch_q},  // shape_O
-        {params.o_row_stride, _1{}, params.o_head_stride, !is_varlen_q ? params.o_batch_stride : 0},  // stride_O
-        static_cast<Element const*>(params.do_ptr),
-        {params.do_row_stride, _1{}, params.do_head_stride, !is_varlen_q ? params.do_batch_stride : 0},  // stride_dO
-        static_cast<float*>(params.dsoftmax_sum),
-        {seqlen_q_rounded, params.h, batch_q},  // shape_dPsum
-        {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0},  // stride_dPsum
-        static_cast<float*>(params.softmax_lse_ptr),
-        {_1{}, seqlen_q, !is_varlen_q ? params.h * params.seqlen_q : 0},  // stride_LSE
-        static_cast<float*>(params.softmax_lse_log2_ptr),
-        {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0},  // stride_LSE_log2
-        static_cast<ElementAccum*>(params.dq_accum_ptr),
-        {seqlen_q_rounded * params.d_rounded, params.h, batch_q},  // shape_dQaccum
-        {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * seqlen_q_rounded * params.h : 0},  // stride_dQaccum
-        params.b,
-        params.dq_semaphore,
-        params.cu_seqlens_q,
-        params.seqused_q
-    };
-    typename PreprocessKernel::Params preprocess_params = PreprocessKernel::to_underlying_arguments(preprocess_args);
-    int num_m_block = cute::ceil_div(params.seqlen_q, kBlockM);
-    dim3 grid_m(num_m_block, params.h, params.b);
-    cutlass::kernel_launch<PreprocessKernel>(grid_m, PreprocessKernel::MaxThreadsPerBlock, PreprocessKernel::SharedStorageSize, stream, preprocess_params, false /*launch_with_pdl*/);
-    CHECK_CUDA_KERNEL_LAUNCH();
-
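
Aside: in the varlen path above, `round_up(total_q + b * kBlockM, kBlockM)` over-allocates one extra block per batch entry so that each sequence can start at a block-aligned offset. A quick standalone check of that arithmetic (plain C++, no CUTLASS):

```cpp
// Standalone check of the varlen rounding used above.
#include <cassert>

constexpr int round_up(int x, int m) { return (x + m - 1) / m * m; }

int main() {
    constexpr int kBlockM = 64;
    int total_q = 1000, b = 3;                            // 3 variable-length sequences
    int padded = round_up(total_q + b * kBlockM, kBlockM);
    assert(padded == 1216);  // 1000 + 3*64 = 1192 -> next multiple of 64 is 1216
    return 0;
}
```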
            -
                using TileShape_MNK = cute::Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
         | 
| 78 | 
            -
                using ClusterShape = cute::Shape<_1, Int<1>, _1>;  // Currently doesn't not support cluster
         | 
| 79 | 
            -
                // Stages_dS_or_QSm80 is Stages_dS if Sm90 and Stages if Sm80
         | 
| 80 | 
            -
                static constexpr int Stages = Arch >= 90 ? 2 : Stages_dS_or_QSm80;
         | 
| 81 | 
            -
                static constexpr int Stages_dS = Arch >= 90 ? Stages_dS_or_QSm80 : 1;
         | 
| 82 | 
            -
                using CollectiveMainloop = std::conditional_t<
         | 
| 83 | 
            -
                    Arch >= 90,
         | 
| 84 | 
            -
                    flash::CollectiveMainloopBwdSm90<Stages, Stages_dO, Stages_dS, ClusterShape, TileShape_MNK, Element, ElementAccum, cutlass::arch::Sm90,
         | 
| 85 | 
            -
                        Is_causal, Is_local, Has_softcap, Varlen, Deterministic,
         | 
| 86 | 
            -
                        SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>,
         | 
| 87 | 
            -
                    flash::CollectiveMainloopBwdSm80<Stages, Stages_dO, TileShape_MNK, Element, ElementAccum, cutlass::arch::Sm80,
         | 
| 88 | 
            -
                        Is_causal, Is_local, Has_softcap, Varlen, Deterministic,
         | 
| 89 | 
            -
                        SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>
         | 
| 90 | 
            -
                >;
         | 
| 91 | 
            -
                using CollectiveEpilogue = std::conditional_t<
         | 
| 92 | 
            -
                    !GQA,
         | 
| 93 | 
            -
                    flash::CollectiveEpilogueBwd<TileShape_MNK, Element, ArchTag, CollectiveMainloop::NumMmaThreads, Varlen, dKV_swapAB, NumMmaWarpGroups * (Arch >= 90 ? 1 : cutlass::NumWarpsPerWarpGroup) / AtomLayoutNdKV>,
         | 
| 94 | 
            -
                    flash::CollectiveEpilogueBwdGQA<TileShape_MNK, ElementAccum, ArchTag, CollectiveMainloop::NumMmaThreads, Varlen, Deterministic>
         | 
| 95 | 
            -
                >;
         | 
| 96 | 
            -
                using Scheduler = flash::SingleTileScheduler<Varlen, false /*Split*/, false /*PackGQA*/, kBlockN>;
         | 
| 97 | 
            -
                using AttnKernel = std::conditional_t<
         | 
| 98 | 
            -
                    Arch >= 90,
         | 
| 99 | 
            -
                    flash::enable_sm90_or_later<flash::FlashAttnBwdSm90<CollectiveMainloop, CollectiveEpilogue, Scheduler>>,
         | 
| 100 | 
            -
                    flash::enable_sm80_to_sm89<flash::FlashAttnBwdSm80<CollectiveMainloop, CollectiveEpilogue, Scheduler>>
         | 
| 101 | 
            -
                >;
         | 
| 102 | 
            -
             | 
| 103 | 
            -
                typename CollectiveMainloop::Arguments mainloop_args {
         | 
| 104 | 
            -
                    static_cast<Element const*>(params.q_ptr),
         | 
| 105 | 
            -
                    {seqlen_q, params.d, params.h, batch_q},  // shape_Q
         | 
| 106 | 
            -
                    {params.q_row_stride, _1{}, params.q_head_stride, !is_varlen_q ? params.q_batch_stride : 0},  // stride_Q
         | 
| 107 | 
            -
                    static_cast<Element const*>(params.k_ptr),
         | 
| 108 | 
            -
                    {seqlen_k, params.d, params.h_k, batch_k},  // shape_K
         | 
| 109 | 
            -
                    {params.k_row_stride, _1{}, params.k_head_stride, !is_varlen_k ? params.k_batch_stride : 0},  // stride_K
         | 
| 110 | 
            -
                    static_cast<Element const*>(params.v_ptr),
         | 
| 111 | 
            -
                    {params.v_row_stride, _1{}, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0},  // stride_V
         | 
| 112 | 
            -
                    static_cast<Element const*>(params.do_ptr),
         | 
| 113 | 
            -
                    {params.do_row_stride, _1{}, params.do_head_stride, !is_varlen_q ? params.do_batch_stride : 0},  // stride_dO
         | 
| 114 | 
            -
                    static_cast<ElementAccum*>(params.dq_accum_ptr),
         | 
| 115 | 
            -
                    {seqlen_q_rounded * params.d_rounded, params.h, batch_q},  // shape_dQaccum
         | 
| 116 | 
            -
                    {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * params.seqlen_q_rounded * params.h : 0}, // stride_dQaccum
         | 
| 117 | 
            -
                    static_cast<float*>(params.softmax_lse_log2_ptr),
         | 
| 118 | 
            -
                    {seqlen_q_rounded, params.h, batch_q},  // shape_LSE
         | 
| 119 | 
            -
                    {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0},  // stride_LSE_log2
         | 
| 120 | 
            -
                    static_cast<float*>(params.dsoftmax_sum),
         | 
| 121 | 
            -
                    {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0},  // stride_dPsum
         | 
| 122 | 
            -
                    params.scale_softmax,
         | 
| 123 | 
            -
                    params.window_size_left, params.window_size_right,
         | 
| 124 | 
            -
                    params.softcap,
         | 
| 125 | 
            -
                    params.b,
         | 
| 126 | 
            -
                    params.dq_semaphore,
         | 
| 127 | 
            -
                    params.cu_seqlens_q, params.cu_seqlens_k,
         | 
| 128 | 
            -
                    params.seqused_q, params.seqused_k
         | 
| 129 | 
            -
                };
         | 
| 130 | 
            -
                // The case work with GQA is ugly but idk how to fix it.
         | 
| 131 | 
            -
                typename CollectiveEpilogue::Arguments epilogue_args {
         | 
| 132 | 
            -
                    static_cast<typename CollectiveEpilogue::Element*>(!GQA ? params.dk_ptr : params.dk_accum_ptr),
         | 
| 133 | 
            -
                    [&] {
         | 
| 134 | 
            -
                        if constexpr (!GQA) {
         | 
| 135 | 
            -
                            return typename CollectiveEpilogue::ShapedKV {seqlen_k, params.d, params.h, batch_k};  // shape_dK
         | 
| 136 | 
            -
                        } else {
         | 
| 137 | 
            -
                            return typename CollectiveEpilogue::ShapedKV {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k};  // shape_dKaccum
         | 
| 138 | 
            -
                        }
         | 
| 139 | 
            -
                    }(),
         | 
| 140 | 
            -
                    [&] {
         | 
| 141 | 
            -
                        if constexpr (!GQA) {
         | 
| 142 | 
            -
                            return typename CollectiveEpilogue::StridedKV {params.dk_row_stride, _1{}, params.dk_head_stride, !is_varlen_k ? params.dk_batch_stride : 0};  // stride_dK
         | 
| 143 | 
            -
                        } else {
         | 
| 144 | 
            -
                            return typename CollectiveEpilogue::StridedKV {_1{}, params.d_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.d_rounded * params.seqlen_k_rounded : 0};  // stride_dKaccum
         | 
| 145 | 
            -
                        }
         | 
| 146 | 
            -
                    }(),
         | 
| 147 | 
            -
                    static_cast<typename CollectiveEpilogue::Element*>(!GQA ? params.dv_ptr : params.dv_accum_ptr),
         | 
| 148 | 
            -
                    [&] {
         | 
| 149 | 
            -
                        if constexpr (!GQA) {
         | 
| 150 | 
            -
                            return typename CollectiveEpilogue::StridedKV {params.dv_row_stride, _1{}, params.dv_head_stride, !is_varlen_k ? params.dv_batch_stride : 0};  // stride_dV
         | 
| 151 | 
            -
                        } else {
         | 
| 152 | 
            -
                            return typename CollectiveEpilogue::StridedKV {_1{}, params.d_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.d_rounded * params.seqlen_k_rounded : 0};  // stride_dVaccum
         | 
| 153 | 
            -
                        }
         | 
| 154 | 
            -
                    }(),
         | 
| 155 | 
            -
                    params.h,
         | 
| 156 | 
            -
                    params.dk_semaphore,
         | 
| 157 | 
            -
                    params.dv_semaphore,
         | 
| 158 | 
            -
                    params.cu_seqlens_k,
         | 
| 159 | 
            -
                    params.seqused_k,
         | 
| 160 | 
            -
                };
         | 
| 161 | 
            -
             | 
| 162 | 
            -
                int num_blocks_n = cutlass::ceil_div(params.seqlen_k, get<1>(TileShape_MNK{}));
         | 
| 163 | 
            -
                num_blocks_n = cutlass::round_up(num_blocks_n, size<1>(ClusterShape{}));
         | 
| 164 | 
            -
                typename flash::TileSchedulerArguments scheduler_args {
         | 
| 165 | 
            -
                    num_blocks_n, params.h, params.b, 1 /*num_splits*/,
         | 
| 166 | 
            -
                    params.h / params.h_k,
         | 
| 167 | 
            -
                    params.seqlen_k,
         | 
| 168 | 
            -
                    params.seqlen_q, params.d, params.dv, sizeof(Element),
         | 
| 169 | 
            -
                    params.tile_count_semaphore, params.cu_seqlens_k, params.seqused_k
         | 
| 170 | 
            -
                };
         | 
| 171 | 
            -
             | 
| 172 | 
            -
                int device;
         | 
| 173 | 
            -
                cudaGetDevice(&device);
         | 
| 174 | 
            -
                typename AttnKernel::Params kernel_params = AttnKernel::to_underlying_arguments({
         | 
| 175 | 
            -
                    mainloop_args, epilogue_args, {device, params.num_sm}, scheduler_args
         | 
| 176 | 
            -
                });
         | 
| 177 | 
            -
             | 
| 178 | 
            -
                dim3 grid_dims = AttnKernel::get_grid_shape(kernel_params);
         | 
| 179 | 
            -
                dim3 block_dims = AttnKernel::get_block_shape();
         | 
| 180 | 
            -
                int smem_size = AttnKernel::SharedStorageSize;
         | 
| 181 | 
            -
                // int smem_size_q = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_q));
         | 
| 182 | 
            -
                // int smem_size_do = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_do));
         | 
| 183 | 
            -
                // int smem_size_ds = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_ds));
         | 
| 184 | 
            -
                // int smem_size_dqacc = [&] {
         | 
| 185 | 
            -
                //     if constexpr (Arch >= 90) {
         | 
| 186 | 
            -
                //         return sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_dqacc));
         | 
| 187 | 
            -
                //     } else {
         | 
| 188 | 
            -
                //         return 0;
         | 
| 189 | 
            -
                //     }
         | 
| 190 | 
            -
                // }();
         | 
| 191 | 
            -
                // int smem_size_k = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k));
         | 
| 192 | 
            -
                // int smem_size_v = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v));
         | 
| 193 | 
            -
                // int smem_size_lse = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_lse));
         | 
| 194 | 
            -
                // int smem_size_dpsum = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_dpsum));
         | 
| 195 | 
            -
                // printf("smem_size = %d, q = %d, k = %d, v = %d, do = %d, ds = %d, dqacc = %d, lse = %d, dpsum = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v, smem_size_do, smem_size_ds, smem_size_dqacc, smem_size_lse, smem_size_dpsum);
         | 
| 196 | 
            -
                if constexpr (size(ClusterShape{}) > 1) {
         | 
| 197 | 
            -
                    void const* kernel = (void const*) cutlass::device_kernel<AttnKernel>;
         | 
| 198 | 
            -
                    if (smem_size >= 48 * 1024) {
         | 
| 199 | 
            -
                        CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
         | 
| 200 | 
            -
                    }
         | 
| 201 | 
            -
                    dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
         | 
| 202 | 
            -
                    cutlass::ClusterLauncher::launch(
         | 
| 203 | 
            -
                        grid_dims, cluster_dims, block_dims, smem_size, stream, kernel, kernel_params, false /*launch_with_pdl*/);
         | 
| 204 | 
            -
                } else {
         | 
| 205 | 
            -
                    if (smem_size >= 48 * 1024) {
         | 
| 206 | 
            -
                        CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel<AttnKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
         | 
| 207 | 
            -
                    }
         | 
| 208 | 
            -
                    cutlass::kernel_launch<AttnKernel>(grid_dims, block_dims, smem_size, stream, kernel_params, false /*launch_with_pdl*/);
         | 
| 209 | 
            -
                }
         | 
| 210 | 
            -
                CHECK_CUDA_KERNEL_LAUNCH();
         | 
| 211 | 
            -
             | 
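
Aside: the `smem_size >= 48 * 1024` branches above reflect a general CUDA requirement: kernels requesting more than 48 KB of dynamic shared memory must opt in via `cudaFuncSetAttribute` before launch. A minimal standalone sketch of that pattern (requires nvcc; the kernel here is illustrative, not the attention kernel):

```cpp
// Opt-in pattern for >48 KB dynamic shared memory (standalone CUDA sketch).
#include <cstdio>
#include <cuda_runtime.h>

__global__ void big_smem_kernel() {
    extern __shared__ float buf[];   // dynamic shared memory
    buf[threadIdx.x] = threadIdx.x;
}

int main() {
    int smem_size = 64 * 1024;  // 64 KB exceeds the default 48 KB limit
    if (smem_size >= 48 * 1024) {
        cudaFuncSetAttribute(big_smem_kernel,
                             cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
    }
    big_smem_kernel<<<1, 128, smem_size>>>();
    printf("launch: %s\n", cudaGetErrorString(cudaGetLastError()));
    return 0;
}
```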
-    using PostprocessKernel = flash::FlashAttnBwdPostprocessConvertdQ<TileShape_MK, Element, ElementAccum, ArchTag,
-        AttnKernel::CollectiveMainloop::NumMmaThreads,
-        typename AttnKernel::CollectiveMainloop::TiledMmadQ,
-        AttnKernel::CollectiveMainloop::dQ_swapAB
-        >;
-    typename PostprocessKernel::Arguments postprocess_args {
-        static_cast<ElementAccum const*>(params.dq_accum_ptr),
-        {seqlen_q_rounded * params.d_rounded, params.h, batch_q},  // shape_dQaccum
-        {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * params.seqlen_q_rounded * params.h : 0}, // stride_dQaccum
-        static_cast<Element*>(params.dq_ptr),
-        {seqlen_q, params.d, params.h, batch_q},  // shape_dQ
-        {params.dq_row_stride, _1{}, params.dq_head_stride, params.dq_batch_stride},  // stride_dQ
-        params.scale_softmax,
-        params.cu_seqlens_q,
-        params.seqused_q
-    };
-    typename PostprocessKernel::Params postprocess_params = PostprocessKernel::to_underlying_arguments(postprocess_args);
-    int num_m_block_postprocess = cute::ceil_div(params.seqlen_q, get<0>(TileShape_MK{}));
-    dim3 grid_m_postprocess(num_m_block_postprocess, params.h, params.b);
-    int smem_size_postprocess = PostprocessKernel::SharedStorageSize;
-    if (smem_size_postprocess >= 48 * 1024) {
-        CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel<PostprocessKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_postprocess));
-    }
-    cutlass::kernel_launch<PostprocessKernel>(grid_m_postprocess, PostprocessKernel::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_params, false /*launch_with_pdl*/);
-    CHECK_CUDA_KERNEL_LAUNCH();
-
-    if constexpr (GQA) {
-        using TileShape_NK = cute::Shape<Int<kBlockN>, Int<kHeadDim>>;
-        using PostprocessKerneldKV = flash::FlashAttnBwdPostprocessConvertdQ<TileShape_NK, Element, ElementAccum, ArchTag,
-            AttnKernel::CollectiveEpilogue::NumEpilogueThreads,
-            typename AttnKernel::CollectiveMainloop::TiledMmadKV,
-            AttnKernel::CollectiveMainloop::dKV_swapAB
-            >;
-        typename PostprocessKerneldKV::Arguments postprocess_dK_args {
-            static_cast<ElementAccum const*>(params.dk_accum_ptr),
-            {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k},  // shape_dKaccum
-            {_1{}, seqlen_k_rounded * params.d_rounded, !is_varlen_k ? params.d_rounded * params.seqlen_k_rounded * params.h_k : 0},  // stride_dKaccum
-            static_cast<Element*>(params.dk_ptr),
-            {seqlen_k, params.d, params.h_k, batch_k},  // shape_dK
-            {params.dk_row_stride, _1{}, params.dk_head_stride, params.dk_batch_stride},  // stride_dK
-            1.f,
-            params.cu_seqlens_k,
-            params.seqused_k
-        };
-        typename PostprocessKerneldKV::Params postprocess_dK_params = PostprocessKerneldKV::to_underlying_arguments(postprocess_dK_args);
-        typename PostprocessKerneldKV::Arguments postprocess_dV_args {
-            static_cast<ElementAccum const*>(params.dv_accum_ptr),
-            {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k},  // shape_dVaccum
-            {_1{}, seqlen_k_rounded * params.d_rounded, !is_varlen_k ? params.d_rounded * params.seqlen_k_rounded * params.h_k : 0},  // stride_dVaccum
-            static_cast<Element*>(params.dv_ptr),
-            {seqlen_k, params.d, params.h_k, batch_k},  // shape_dV
-            {params.dv_row_stride, _1{}, params.dv_head_stride, params.dv_batch_stride},  // stride_dV
-            1.f,
-            params.cu_seqlens_k,
-            params.seqused_k
-        };
-        typename PostprocessKerneldKV::Params postprocess_dV_params = PostprocessKerneldKV::to_underlying_arguments(postprocess_dV_args);
-        int num_n_block_postprocess = cute::ceil_div(params.seqlen_k, get<0>(TileShape_NK{}));
-        dim3 grid_n_postprocess(num_n_block_postprocess, params.h_k, params.b);
-        int smem_size_postprocess = PostprocessKerneldKV::SharedStorageSize;
-        if (smem_size_postprocess >= 48 * 1024) {
-            CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel<PostprocessKerneldKV>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_postprocess));
-        }
-        cutlass::kernel_launch<PostprocessKerneldKV>(grid_n_postprocess, PostprocessKerneldKV::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_dK_params, false /*launch_with_pdl*/);
-        CHECK_CUDA_KERNEL_LAUNCH();
-        cutlass::kernel_launch<PostprocessKerneldKV>(grid_n_postprocess, PostprocessKerneldKV::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_dV_params, false /*launch_with_pdl*/);
-        CHECK_CUDA_KERNEL_LAUNCH();
-    }
-
-}
-
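
Aside: the dQ (and, under GQA, dK/dV) postprocess kernels above convert fp32 accumulator buffers to the storage dtype in a separate pass. A toy standalone illustration of why one accumulates in high precision and converts once at the end (low precision is simulated here by rounding to a coarse grid; this is illustrative only, not the kernel's arithmetic):

```cpp
// Accumulate-then-convert vs. converting after every step.
#include <cmath>
#include <cstdio>

float to_low_precision(float x) { return std::round(x * 8.0f) / 8.0f; }  // 1/8 grid

int main() {
    float acc_hi = 0.0f, acc_lo = 0.0f;
    for (int i = 0; i < 10000; ++i) {
        acc_hi += 1e-2f;                            // full-precision accumulation
        acc_lo = to_low_precision(acc_lo + 1e-2f);  // rounding every step stalls at 0
    }
    printf("convert once at the end: %g\n", to_low_precision(acc_hi));  // ~100
    printf("round every step:        %g\n", acc_lo);                    // 0
    return 0;
}
```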
-template<int Arch, typename T, int kBlockM, int kBlockN, int kHeadDim, bool Is_causal, bool Is_local, bool Has_softcap,
-         int Stages_dO=2, int Stages_dS_or_QSm80=2,
-         bool SdP_swapAB=true, bool dKV_swapAB=false, bool dQ_swapAB=false,
-         int NumMmaWarpGroups=2, int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1,
-         bool V_in_regs=false>
-void run_mha_bwd_dispatch(Flash_bwd_params &params, cudaStream_t stream) {
-    VARLEN_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] {
-        BOOL_SWITCH(params.h != params.h_k, GQA, [&] {
-//             BOOL_SWITCH(params.deterministic, Deterministic, [&] {
-            // run_flash_bwd<kHeadDim, kBlockM, kBlockN, T, Is_causal, Is_local, Has_softcap, Varlen, false, GQA, Stages_dO, Stages_dS_or_QSm80, SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ>(params, stream);
-            run_flash_bwd<Arch, kHeadDim, kBlockM, kBlockN, T, Is_causal, Is_local, Has_softcap, Varlen /*Varlen*/, false /*Deterministic*/, GQA, Stages_dO, Stages_dS_or_QSm80, SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>(params, stream);
-//             });
-        });
-    });
-}
-
-
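
Aside: `VARLEN_SWITCH`/`BOOL_SWITCH` (from `static_switch.h`) turn runtime booleans into compile-time constants so each combination gets its own fully specialized kernel instantiation. A minimal macro in that spirit (a sketch, not the repo's exact definition):

```cpp
// Turn a runtime bool into a compile-time constant usable as a template arg.
#include <cstdio>

#define BOOL_SWITCH(COND, CONST_NAME, ...)              \
    do {                                                \
        if (COND) {                                     \
            constexpr static bool CONST_NAME = true;    \
            __VA_ARGS__();                              \
        } else {                                        \
            constexpr static bool CONST_NAME = false;   \
            __VA_ARGS__();                              \
        }                                               \
    } while (0)

template <bool GQA> void kernel_stub() { printf("GQA=%d\n", GQA); }

int main() {
    int h = 32, h_k = 8;  // runtime values: grouped-query if h != h_k
    BOOL_SWITCH(h != h_k, GQA, [&] { kernel_stub<GQA>(); });
    return 0;
}
```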
-template<int Arch, typename T, bool Has_softcap>
-void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
-    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
-        if constexpr (Arch >= 90) {
-            if constexpr (Is_causal && Has_softcap) {
-                // register spill with 128 x 128
-                run_mha_bwd_dispatch<Arch, T, 96, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, true, false, true, 2, 1, 2, 2, false>(params, stream);
-            } else {
-                // With ShuffleStats we no longer have register spilling when Has_softcap and using 128 x 128 block.
-                run_mha_bwd_dispatch<Arch, T, 128, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, true, false, false, 2, 1, 2, 2, false>(params, stream);
-            }
-        } else if constexpr (Arch == 86 || Arch == 89) {
-            run_mha_bwd_dispatch<Arch, T, 64, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 4, 2, true>(params, stream);
-            // run_mha_bwd_dispatch<Arch, T, 96, 96, 64, Is_causal, Is_local, Has_softcap, 1, 2, false, true, true, 2, 2, 4, 4, false>(params, stream);
-            // run_mha_bwd_dispatch<Arch, T, 80, 128, 64, Is_causal, Is_local, Has_softcap, 1, 2, true, false, true, 2, 2, 4, 2, true>(params, stream);
-            // run_mha_bwd_dispatch<Arch, T, 96, 128, 64, Is_causal, Is_local, Has_softcap, 1, 2, true, false, true, 2, 1, 8, 4, false>(params, stream);
-        } else {
-            run_mha_bwd_dispatch<Arch, T, 128, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 4, 4, 4, false>(params, stream);
-        }
-    });
-}
-
-template<int Arch, typename T, bool Has_softcap>
-void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
-    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
-        if constexpr (Arch >= 90) {
-            run_mha_bwd_dispatch<Arch, T, 64, 128, 96, Is_causal, Is_local, Has_softcap, 2, 2, true, false, false, 2, 1, 2, 1, true>(params, stream);
-        } else if constexpr (Arch == 86 || Arch == 89) {
-            run_mha_bwd_dispatch<Arch, T, 64, 128, 96, Is_causal, Is_local, Has_softcap, 1, 2, false, false, false, 2, 2, 4, 2, true>(params, stream);
-        } else {
-            run_mha_bwd_dispatch<Arch, T, 64, 128, 96, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 4, 2, false>(params, stream);
-        }
-    });
-}
-
-template<int Arch, typename T, bool Has_softcap>
-void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
-    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
-        if constexpr (Arch >= 90) {
-            if constexpr (Is_causal || Is_local || Has_softcap) {
-                run_mha_bwd_dispatch<Arch, T, 64, 128, 128, Is_causal, Is_local, Has_softcap, 2, 2, true, false, false, 2, 1, 2, 1, false>(params, stream);
-            } else {
-                run_mha_bwd_dispatch<Arch, T, 80, 128, 128, Is_causal, Is_local, Has_softcap, 2, 2, true, false, true, 2, 1, 2, 1, false>(params, stream);
-            }
-        } else if constexpr (Arch == 86 || Arch == 89) {
-            run_mha_bwd_dispatch<Arch, T, 64, 96, 128, Is_causal, Is_local, Has_softcap, 1, 2, false, false, false, 2, 2, 2, 2, true>(params, stream);
-        } else {
-            run_mha_bwd_dispatch<Arch, T, 64, 128, 128, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 2, 2, false>(params, stream);
-        }
-    });
-}
-
-template<int Arch, typename T, bool Has_softcap>
-void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
-    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
-        if constexpr (Arch >= 90) {
-            run_mha_bwd_dispatch<Arch, T, 64, 96, 192, Is_causal, Is_local, Has_softcap, 1, 1, false, true, false, 3, 1, 1, 1, false>(params, stream);
-        } else if constexpr (Arch == 86 || Arch == 89) {
-            run_mha_bwd_dispatch<Arch, T, 64, 64, 192, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 2, 2, 2, true>(params, stream);
-        } else {
-            run_mha_bwd_dispatch<Arch, T, 64, 80, 192, Is_causal, Is_local, Has_softcap, 1, 2, false, true, false, 2, 4, 2, 2, false>(params, stream);
-        }
-    });
-}
-
-template<int Arch, typename T, bool Has_softcap>
-void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
-    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
-        if constexpr (Arch >= 90) {
-            run_mha_bwd_dispatch<Arch, T, 64, 80, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, true, true, 2, 1, 1, 1, false>(params, stream);
-        } else if constexpr (Arch == 86 || Arch == 89) {
-            run_mha_bwd_dispatch<Arch, T, 32, 64, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 2, 2, 1, true>(params, stream);
-            // run_mha_bwd_dispatch<Arch, T, 64, 32, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 4, 1, 2, true>(params, stream);
-        } else {
-            run_mha_bwd_dispatch<Arch, T, 64, 64, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 4, 2, 2, false>(params, stream);
-        }
-    });
-}
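
Aside: each `run_mha_bwd_hdim*` function above pins a tile configuration per (architecture, head dim) pair; callers pick the bucket by rounding the head dim up to the nearest supported size. A hypothetical standalone sketch of that routing (stub functions stand in for the real instantiations):

```cpp
// Head-dim bucket dispatch, in the style of the run_mha_bwd_hdim* family.
#include <cstdio>
#include <stdexcept>

template <int kHeadDim> void run_bwd_stub() { printf("hdim bucket: %d\n", kHeadDim); }

void dispatch_headdim(int d) {
    if      (d <= 64)  run_bwd_stub<64>();
    else if (d <= 96)  run_bwd_stub<96>();
    else if (d <= 128) run_bwd_stub<128>();
    else if (d <= 192) run_bwd_stub<192>();
    else if (d <= 256) run_bwd_stub<256>();
    else throw std::runtime_error("head dim > 256 not supported");
}

int main() { dispatch_headdim(80); dispatch_headdim(128); return 0; }
```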
flash-attn/flash_bwd_postprocess_kernel.h DELETED
@@ -1,256 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
- ******************************************************************************/
-
-#pragma once
-
-#include "cute/tensor.hpp"
-
-#include <cutlass/cutlass.h>
-#include <cutlass/array.h>
-#include <cutlass/numeric_types.h>
-#include <cutlass/numeric_conversion.h>
-#include "cutlass/arch/barrier.h"
-
-#include "seqlen.h"
-#include "utils.h"
-
-namespace flash {
-
-using namespace cute;
-
-template <class TileShape_MK_, class Element, class ElementAccum, class ArchTag_, int kNThreads, class TiledMma, bool dQ_swapAB>
-class FlashAttnBwdPostprocessConvertdQ {
-
-public:
-
-    // Type Aliases
-    using TileShape_MK = TileShape_MK_;
-    using ArchTag = ArchTag_;
-
-    static_assert(ArchTag::kMinComputeCapability >= 75);
-    static constexpr bool IsSm90 = ArchTag::kMinComputeCapability >= 90;
-
-    static constexpr uint32_t MaxThreadsPerBlock = kNThreads;
-    static constexpr uint32_t MinBlocksPerMultiprocessor = 2;
-
-    static constexpr int kBlockM = get<0>(TileShape_MK{});
-    static constexpr int kHeadDim = get<1>(TileShape_MK{});
-    static_assert(!IsSm90 || kNThreads % cutlass::NumThreadsPerWarpGroup == 0, "kNThreads must be a multiple of NumThreadsPerWarpGroup");
-    static constexpr int NumdQWarpGgroups = kNThreads / cutlass::NumThreadsPerWarpGroup;
-    using R2SLayoutAtomdQaccum = std::conditional_t<
-        IsSm90,
-        Layout<Shape<Int<cutlass::NumThreadsPerWarpGroup>, Int<NumdQWarpGgroups>>>,
-        Layout<Shape<Int<kNThreads>>>
-    >;
-    using R2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, R2SLayoutAtomdQaccum{},
-                                                         Layout<Shape<Int<IsSm90 ? 4 : 1>>>{}));  // Val layout, 1 or 4 vals per read
-    using G2SLayoutAtomdQaccum = Layout<Shape<Int<kNThreads>>>;
-    // UniversalCopy instead of AutoVectorizingCopyWithAssumedAlignment as the latter generates cp.async instructions
-    using G2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, ElementAccum>{}, G2SLayoutAtomdQaccum{},
-                                                         Layout<Shape<_4>>{}));  // Val layout, 4 vals per read
-    // We don't do bounds checking for the gmem -> smem load, so we just assert here.
-    static_assert(IsSm90 || (kBlockM * kHeadDim) % (kNThreads * 4) == 0);
-    static constexpr int SmemdQaccumSize = size(TileShape_MK{});
-    using SmemLayoutdQaccumFlat = Layout<Shape<Int<SmemdQaccumSize>>>;
-    using SmemLayoutdQaccum = std::conditional_t<
-        IsSm90,
-        Layout<Shape<Int<kBlockM * kHeadDim / NumdQWarpGgroups>, Int<NumdQWarpGgroups>>>,
-        Layout<Shape<Int<kBlockM * kHeadDim>>>
-    >;
-
-    // We can't just use kHeadDim here. E.g. if MMA shape is 64 x 96 but split across 2 WGs,
-    // then setting kBlockKSmem to 32 will cause "Static shape_div failure".
-    // We want to treat it as 64 x 48, so kBlockKSmem should be 16.
-    static constexpr int MmaShapeN = get<1>(typename TiledMma::AtomShape_MNK{});
-    static constexpr int kBlockKSmem = MmaShapeN % 64 == 0 ? 64 : (MmaShapeN % 32 == 0 ? 32 : 16);
-    static constexpr int kSwizzle = kBlockKSmem == 64 ? 3 : (kBlockKSmem == 32 ? 2 : 1);
-    using SmemLayoutAtomdQ =
-        decltype(composition(Swizzle<kSwizzle, 3, 3>{},
-                 Layout<Shape<Int<8>, Int<kBlockKSmem>>,
-                 Stride<Int<kBlockKSmem>, _1>>{}));
-    using SmemLayoutdQ = decltype(tile_to_shape(SmemLayoutAtomdQ{}, TileShape_MK{}));
-    using SmemLayoutdQt =
-        decltype(cute::composition(SmemLayoutdQ{},
-                                   make_layout(make_shape(get<1>(TileShape_MK{}), get<0>(TileShape_MK{})),
-                                               make_stride(Int<get<0>(TileShape_MK{})>{}, _1{}))));
-
-    using SmemCopyAtomdQ = Copy_Atom<
-        std::conditional_t<
-            IsSm90,
-            std::conditional_t<!dQ_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
-            AutoVectorizingCopyWithAssumedAlignment<128>
-        >,
-        Element>;
-
-    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
-    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
-    static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerLoad, int(MaxThreadsPerBlock));
-    static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow");
-    using GmemLayoutAtom = Layout<Shape <Int<MaxThreadsPerBlock / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
-                                  Stride<Int<kGmemThreadsPerRow>, _1>>;
-    using GmemTiledCopy = decltype(
-        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
-                        GmemLayoutAtom{},
-                        Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{}));  // Val layout, 8 or 16 vals per load
-
-    struct SharedStorage : cute::aligned_struct<128> {
-        cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutdQaccum>> smem_dqacc;
-        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdQ>> smem_dq;
-        alignas(16) cutlass::arch::ClusterTransactionBarrier barrier_dQaccum;
-    };
-
-    static constexpr int SharedStorageSize = sizeof(SharedStorage);
-
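
Aside: the `kBlockKSmem` selection above picks the largest of {64, 32, 16} that divides the per-warpgroup MMA N extent, so the swizzled smem atom always tiles evenly. A quick standalone check of that selection logic (plain C++):

```cpp
// The same largest-divisor-in-{64,32,16} selection as kBlockKSmem above.
#include <cstdio>

constexpr int pick_kBlockKSmem(int mma_shape_n) {
    return mma_shape_n % 64 == 0 ? 64 : (mma_shape_n % 32 == 0 ? 32 : 16);
}

int main() {
    printf("%d %d %d\n",
           pick_kBlockKSmem(128),  // 64
           pick_kBlockKSmem(96),   // 32
           pick_kBlockKSmem(48));  // 16: e.g. a 64 x 96 MMA split across 2 WGs
    return 0;
}
```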
-    using ShapedQ = cute::Shape<int32_t, int32_t, int32_t, int32_t>;  // (seqlen_q, d, head, batch)
-    using StridedQ = cute::Stride<int64_t, _1, int64_t, int64_t>;
-    using ShapedQaccum = cute::Shape<int32_t, int32_t, int32_t>;  // (seqlen_q * d, head, batch)
-    using StridedQaccum = cute::Stride<_1, int64_t, int64_t>;
-
-    // Device side arguments
-    struct Arguments {
-        ElementAccum const* ptr_dQaccum;
-        ShapedQaccum const shape_dQaccum;
-        StridedQaccum const stride_dQaccum;
-        Element* ptr_dQ;
-        ShapedQ const shape_dQ;
-        StridedQ const stride_dQ;
-        float const softmax_scale;
-        int const* cu_seqlens = nullptr;
-        int const* seqused = nullptr;
-    };
-
-    // Kernel entry point API
-    struct Params {
-        ElementAccum const* ptr_dQaccum;
-        ShapedQaccum const shape_dQaccum;
-        StridedQaccum const stride_dQaccum;
-        Element* ptr_dQ;
-        ShapedQ const shape_dQ;
-        StridedQ const stride_dQ;
-        float const softmax_scale;
-        int const* cu_seqlens = nullptr;
-        int const* seqused = nullptr;
-    };
-
-    // Convert to underlying arguments. In this case, a simple copy for the aliased type.
-    static
-    Params
-    to_underlying_arguments(Arguments const& args) {
-        return {
-            args.ptr_dQaccum,
-            args.shape_dQaccum,
-            args.stride_dQaccum,
-            args.ptr_dQ,
-            args.shape_dQ,
-            args.stride_dQ,
-            args.softmax_scale,
-            args.cu_seqlens,
-            args.seqused
-        };
-    }
-
             | 
| 153 | 
            -
                CUTLASS_DEVICE
         | 
| 154 | 
            -
                void
         | 
| 155 | 
            -
                operator()(Params const& params, char* smem_buf) {
         | 
| 156 | 
            -
             | 
| 157 | 
            -
                    static constexpr int kBlockM = get<0>(TileShape_MK{});
         | 
| 158 | 
            -
                    SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
         | 
| 159 | 
            -
             | 
| 160 | 
            -
                    Tensor sdQaccum = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), SmemLayoutdQaccum{});
         | 
| 161 | 
            -
                    Tensor sdQaccum_flat = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), SmemLayoutdQaccumFlat{});
         | 
| 162 | 
            -
                    Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), SmemLayoutdQ{});
         | 
| 163 | 
            -
                    Tensor sdQt = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), SmemLayoutdQt{});
         | 
| 164 | 
            -
             | 
| 165 | 
            -
                    int const thread_idx = threadIdx.x;
         | 
| 166 | 
            -
                    int const m_block = blockIdx.x;
         | 
| 167 | 
            -
                    int const bidh = blockIdx.y;
         | 
| 168 | 
            -
                    int const bidb = blockIdx.z;
         | 
| 169 | 
            -
             | 
| 170 | 
            -
                    flash::SeqlenInfo<true /*Varlen*/, kBlockM> seqlen_info(bidb, size<0>(params.shape_dQ), params.cu_seqlens, params.seqused);
         | 
| 171 | 
            -
                    bool const is_varlen = params.cu_seqlens;
         | 
| 172 | 
            -
                    if (is_varlen && m_block * kBlockM >= seqlen_info.seqlen) { return; }
         | 
| 173 | 
            -
             | 
| 174 | 
            -
                    // Step 1: load dQaccum from gmem to smem
         | 
| 175 | 
            -
                    Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum const*>(params.ptr_dQaccum)),
         | 
| 176 | 
            -
                                                  params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0);
         | 
| 177 | 
            -
                    Tensor gdQaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdQaccum), Shape<Int<kBlockM * kHeadDim>>{}, make_coord(m_block));  // (M * K)
         | 
| 178 | 
            -
                    if constexpr (IsSm90) {  // Use BulkCopy
         | 
| 179 | 
            -
                        static constexpr uint32_t TmaTransactionBytesdQaccum = static_cast<uint32_t>(size(SmemLayoutdQaccumFlat{}) * cute::sizeof_bits_v<ElementAccum> / 8);
         | 
| 180 | 
            -
                        auto bulk_copy = Copy_Traits<SM90_BULK_COPY_AUTO>{};
         | 
| 181 | 
            -
                        // if (thread0()) { print(gdQaccum); printf("\n"); print(sdQaccum_flat); printf("\n"); }
         | 
| 182 | 
            -
                        if (thread_idx == 0) {
         | 
| 183 | 
            -
                            shared_storage.barrier_dQaccum.init(1 /*numThreads*/);
         | 
| 184 | 
            -
                            shared_storage.barrier_dQaccum.arrive_and_expect_tx(TmaTransactionBytesdQaccum);
         | 
| 185 | 
            -
                            copy(bulk_copy.with(*reinterpret_cast<uint64_t*>(&shared_storage.barrier_dQaccum)), gdQaccum, sdQaccum_flat);
         | 
| 186 | 
            -
                        }
         | 
| 187 | 
            -
                        __syncthreads();
         | 
| 188 | 
            -
                        shared_storage.barrier_dQaccum.wait(0);
         | 
| 189 | 
            -
                    } else {
         | 
| 190 | 
            -
                        G2STiledCopydQaccum g2s_tiled_copy_dQaccum;
         | 
| 191 | 
            -
                        auto g2s_thr_copy_dQaccum = g2s_tiled_copy_dQaccum.get_thread_slice(thread_idx);
         | 
| 192 | 
            -
                        Tensor tdQgdQaccumg2s = g2s_thr_copy_dQaccum.partition_S(gdQaccum);
         | 
| 193 | 
            -
                        Tensor tdQsdQaccumg2s = g2s_thr_copy_dQaccum.partition_D(sdQaccum);
         | 
| 194 | 
            -
                        cute::copy(g2s_tiled_copy_dQaccum, tdQgdQaccumg2s, tdQsdQaccumg2s);
         | 
| 195 | 
            -
                        __syncthreads();
         | 
| 196 | 
            -
                    }
         | 
| 197 | 
            -
             | 
| 198 | 
            -
                    // __syncthreads(); if (cute::thread0()) { print_tensor(sdQaccum); }
         | 
| 199 | 
            -
             | 
| 200 | 
            -
                    // Step 2: Load dQaccum from smem to register, then convert fp32 -> fp16/bf16
         | 
| 201 | 
            -
                    R2STiledCopydQaccum s2r_tiled_copy_dQaccum;
         | 
| 202 | 
            -
                    auto s2r_thr_copy_dQaccum = s2r_tiled_copy_dQaccum.get_thread_slice(thread_idx);
         | 
| 203 | 
            -
                    Tensor tdQsdQaccum = s2r_thr_copy_dQaccum.partition_S(sdQaccum);
         | 
| 204 | 
            -
                    TiledMma tiled_mma_dQ;
         | 
| 205 | 
            -
                    Tensor taccdQrdQaccum = partition_fragment_C(tiled_mma_dQ, select<!dQ_swapAB ? 0 : 1, !dQ_swapAB ? 1 : 0>(TileShape_MK{}));
         | 
| 206 | 
            -
                    // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(tiled_mma_dQ); printf("\n"); }
         | 
| 207 | 
            -
                    // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(tdQsdQaccum); }
         | 
| 208 | 
            -
                    // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(taccdQrdQaccum); }
         | 
| 209 | 
            -
                    CUTE_STATIC_ASSERT_V(size(taccdQrdQaccum) == size(tdQsdQaccum));
         | 
| 210 | 
            -
                    Tensor tdQrdQaccum = s2r_thr_copy_dQaccum.retile_D(taccdQrdQaccum);
         | 
| 211 | 
            -
                    cute::copy(s2r_tiled_copy_dQaccum, tdQsdQaccum, tdQrdQaccum);
         | 
| 212 | 
            -
                    #pragma unroll
         | 
| 213 | 
            -
                    for (int i = 0; i < size(taccdQrdQaccum); ++i) { taccdQrdQaccum(i) *= params.softmax_scale; }
         | 
| 214 | 
            -
                    // Convert tdQrdQ from fp32 to fp16
         | 
| 215 | 
            -
                    Tensor rdQ = make_tensor_like<Element>(taccdQrdQaccum);
         | 
| 216 | 
            -
                    flash::convert_type_out(taccdQrdQaccum, rdQ);
         | 
| 217 | 
            -
             | 
| 218 | 
            -
                    // Step 3: Copy dQ from register to smem
         | 
| 219 | 
            -
                    auto smem_tiled_copy_dQ = make_tiled_copy_C(SmemCopyAtomdQ{}, tiled_mma_dQ);
         | 
| 220 | 
            -
                    auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(thread_idx);
         | 
| 221 | 
            -
                    Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ);  // ((Atom,AtomNum), MMA_N, MMA_N)
         | 
| 222 | 
            -
                    // if (cute::thread0()) { print(smem_tiled_copy_dQ); }
         | 
| 223 | 
            -
                    // if (cute::thread0()) { print(smem_thr_copy_dQ); }
         | 
| 224 | 
            -
                    // if (cute::thread0()) { print(sdQ); }
         | 
| 225 | 
            -
                    Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(cute::conditional_return<!dQ_swapAB>(sdQ, sdQt));  // ((Atom,AtomNum),PIPE_M,PIPE_N)
         | 
| 226 | 
            -
                    cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
         | 
| 227 | 
            -
                    __syncthreads();
         | 
| 228 | 
            -
             | 
| 229 | 
            -
                    // Step 4: Copy dQ from smem to register to prepare for coalesced write to gmem
         | 
| 230 | 
            -
                    Tensor mdQ = make_tensor(make_gmem_ptr(params.ptr_dQ), params.shape_dQ, params.stride_dQ)(_, _, bidh, !is_varlen ? bidb : 0);
         | 
| 231 | 
            -
                    Tensor gdQ = local_tile(domain_offset(make_coord(seqlen_info.offset, _0{}), mdQ), TileShape_MK{}, make_coord(m_block, _0{}));  // (M, K)
         | 
| 232 | 
            -
                    GmemTiledCopy gmem_tiled_copy_dQ;
         | 
| 233 | 
            -
                    auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(thread_idx);
         | 
| 234 | 
            -
                    Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ);    // ((Atom,AtomNum),ATOM_M,ATOM_N)
         | 
| 235 | 
            -
                    Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
         | 
| 236 | 
            -
             | 
| 237 | 
            -
                    Tensor tdQrdQ = make_fragment_like(tdQsdQ);
         | 
| 238 | 
            -
                    Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cute::make_identity_tensor(TileShape_MK{}));
         | 
| 239 | 
            -
                    Tensor tdQpdQ = make_tensor<bool>(make_shape(size<2>(tdQgdQ)));
         | 
| 240 | 
            -
                    #pragma unroll
         | 
| 241 | 
            -
                    for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(_0{}, _0{}, k)) < get<1>(params.shape_dQ); }
         | 
| 242 | 
            -
                    // Need to check OOB when reading from smem if kBlockM isn't evenly tiled
         | 
| 243 | 
            -
                    static constexpr bool EvenM = kBlockM % CUTE_STATIC_V(size<0>(GmemLayoutAtom{})) == 0;
         | 
| 244 | 
            -
                    flash::copy</*Is_even_MN=*/EvenM, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false>(
         | 
| 245 | 
            -
                        gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ, tdQcdQ, tdQpdQ, kBlockM);
         | 
| 246 | 
            -
             | 
| 247 | 
            -
                    // Step 5: Copy dQ from register to gmem
         | 
| 248 | 
            -
                    // Clear_OOB_K must be false since we don't want to write zeros to gmem
         | 
| 249 | 
            -
                    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
         | 
| 250 | 
            -
                        gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, std::min(seqlen_info.seqlen - m_block * kBlockM, kBlockM)
         | 
| 251 | 
            -
                    );
         | 
| 252 | 
            -
                }
         | 
| 253 | 
            -
             | 
| 254 | 
            -
            };
         | 
| 255 | 
            -
             | 
| 256 | 
            -
            } // namespace flash
         | 
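
For readers skimming this removal: the postprocess kernel above turns the fp32 dQ accumulator into the final fp16/bf16 dQ. As a minimal sketch of just the Step 2 arithmetic (a hypothetical standalone kernel with made-up names; the removed code does the same thing through smem staging and vectorized/bulk copies):

```cuda
#include <cuda_fp16.h>

// Hypothetical reference for Step 2 above: scale the fp32 dQ accumulator by
// softmax_scale, then narrow fp32 -> fp16. n = seqlen_q * head_dim elements
// of one (head, batch) slice.
__global__ void dq_postprocess_ref(const float* __restrict__ dq_accum,
                                   __half* __restrict__ dq,
                                   float softmax_scale, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        dq[i] = __float2half(dq_accum[i] * softmax_scale);
    }
}
```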
flash-attn/flash_bwd_preprocess_kernel.h
DELETED
@@ -1,252 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
- ******************************************************************************/
-
-#pragma once
-
-#include "cute/tensor.hpp"
-
-#include <cutlass/cutlass.h>
-#include <cutlass/array.h>
-#include <cutlass/numeric_types.h>
-#include <cutlass/numeric_conversion.h>
-
-#include "seqlen.h"
-#include "utils.h"
-
-namespace flash {
-
-using namespace cute;
-
-template <class TileShape_MK_, class Element, class ElementAccum, class ArchTag_, bool Clear_dQaccum, bool Varlen>
-class FlashAttnBwdPreprocess {
-
-public:
-
-    // Type Aliases
-    using TileShape_MK = TileShape_MK_;
-    using ArchTag = ArchTag_;
-
-    static_assert(std::is_same_v<Element, cutlass::half_t> && ArchTag::kMinComputeCapability >= 75 ||
-                  std::is_same_v<Element, cutlass::bfloat16_t> && ArchTag::kMinComputeCapability >= 80 ||
-                  std::is_same_v<Element, cutlass::float_e4m3_t> && ArchTag::kMinComputeCapability >= 89);
-
-    static constexpr uint32_t MaxThreadsPerBlock = 256;
-    static constexpr uint32_t MinBlocksPerMultiprocessor = 2;
-    static constexpr int SharedStorageSize = 0;
-
-    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
-    static_assert(get<1>(TileShape_MK{}) % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
-    static constexpr int kBlockM = get<0>(TileShape_MK{});
-    static constexpr int kHeadDim = get<1>(TileShape_MK{});
-    // We want kBlockKGmem to be a power of 2 so that when we do the summing,
-    // it's just between threads in the same warp
-    static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
-    static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad;
-    static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow");
-    using GmemLayoutAtom = Layout<Shape <Int<MaxThreadsPerBlock / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
-                                  Stride<Int<kGmemThreadsPerRow>, _1>>;
-    using GmemTiledCopy = decltype(
-        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
-                        GmemLayoutAtom{},
-                        Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{}));  // Val layout, 8 or 16 vals per load
-
-    static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum);
-    static_assert((kBlockM * kHeadDim / kGmemElemsPerLoadAccum) % MaxThreadsPerBlock == 0, "MaxThreadsPerBlock must divide kBlockM * kHeadDim / kGmemElemsPerLoadAccum");
-    using GmemLayoutAtomAccum = Layout<Shape<Int<MaxThreadsPerBlock>>>;
-    using GmemTiledCopyAccum = decltype(
-        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
-                        GmemLayoutAtomAccum{},
-                        Layout<Shape<Int<kGmemElemsPerLoadAccum>>>{}));  // Val layout, 4 vals per store
-
-    using ShapeO = cute::Shape<int32_t, int32_t, int32_t, int32_t>;  // (seqlen_q, d, head, batch)
-    using StrideO = cute::Stride<int64_t, _1, int64_t, int64_t>;
-    using ShapedPsum = cute::Shape<int32_t, int32_t, int32_t>;  // (seqlen_q, head, batch)
-    using StridedPsum = cute::Stride<_1, int64_t, int64_t>;
-    using ShapedQaccum = cute::Shape<int32_t, int32_t, int32_t>;  // (seqlen_q * d, head, batch)
-    using StridedQaccum = cute::Stride<_1, int64_t, int64_t>;
-
-    // Device side arguments
-    struct Arguments {
-        Element const* ptr_O;
-        ShapeO const shape_O;
-        StrideO const stride_O;
-        Element const* ptr_dO;
-        StrideO const stride_dO;
-        float* ptr_dPsum;
-        ShapedPsum const shape_dPsum;
-        StridedPsum const stride_dPsum;
-        float const* ptr_LSE;
-        StridedPsum const stride_LSE;
-        float *ptr_LSE_log2;
-        StridedPsum const stride_LSE_log2;
-        ElementAccum* ptr_dQaccum;
-        ShapedQaccum const shape_dQaccum;
-        StridedQaccum const stride_dQaccum;
-        int num_batch;  // We need this to know the size of dq_semaphore in case of varlen
-        int* dq_semaphore;
-        int const* cu_seqlens = nullptr;
-        int const* seqused = nullptr;
-    };
-
-    // Kernel entry point API
-    struct Params {
-        Element const* ptr_O;
-        ShapeO const shape_O;
-        StrideO const stride_O;
-        Element const* ptr_dO;
-        StrideO const stride_dO;
-        float* ptr_dPsum;
-        ShapedPsum const shape_dPsum;
-        StridedPsum const stride_dPsum;
-        float const* ptr_LSE;
-        StridedPsum const stride_LSE;
-        float* ptr_LSE_log2;
-        StridedPsum const stride_LSE_log2;
-        ElementAccum* ptr_dQaccum;
-        ShapedQaccum const shape_dQaccum;
-        StridedQaccum const stride_dQaccum;
-        int num_batch;
-        int* dq_semaphore;
-        int const* cu_seqlens = nullptr;
-        int const* seqused = nullptr;
-    };
-
-    // Convert to underlying arguments. In this case, a simple copy for the aliased type.
-    static
-    Params
-    to_underlying_arguments(Arguments const& args) {
-        return {
-            args.ptr_O,
-            args.shape_O,
-            args.stride_O,
-            args.ptr_dO,
-            args.stride_dO,
-            args.ptr_dPsum,
-            args.shape_dPsum,
-            args.stride_dPsum,
-            args.ptr_LSE,
-            args.stride_LSE,
-            args.ptr_LSE_log2,
-            args.stride_LSE_log2,
-            args.ptr_dQaccum,
-            args.shape_dQaccum,
-            args.stride_dQaccum,
-            args.num_batch,
-            args.dq_semaphore,
-            args.cu_seqlens,
-            args.seqused
-        };
-    }
-
-    CUTLASS_DEVICE
-    void
-    operator()(Params const& params, [[maybe_unused]] char* smem_buf) {
-
-        static constexpr int kBlockM = get<0>(TileShape_MK{});
-
-        int const thread_idx = threadIdx.x;
-        int const m_block = blockIdx.x;
-        int const bidh = blockIdx.y;
-        int const bidb = blockIdx.z;
-
-        flash::SeqlenInfo<Varlen, kBlockM> seqlen_info(bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused);
-        bool const is_varlen = Varlen && params.cu_seqlens;
-        int const seqlen_o = seqlen_info.seqlen;
-        if (is_varlen && m_block * kBlockM >= seqlen_o) { return; }
-
-        Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O), params.shape_O, params.stride_O)(_, _, bidh, !is_varlen ? bidb : 0);
-        Tensor gO = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mO), TileShape_MK{}, make_coord(m_block, _0{}));  // (M, K)
-        Tensor mdO = make_tensor(make_gmem_ptr(params.ptr_dO), params.shape_O, params.stride_dO)(_, _, bidh, !is_varlen ? bidb : 0);
-        Tensor gdO = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdO), TileShape_MK{}, make_coord(m_block, _0{}));  // (M, K)
-
-        auto shape_LSE = select<0, 2, 3>(params.shape_O);
-        Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE), shape_LSE, params.stride_LSE)(_, bidh, !is_varlen ? bidb : 0);
-        Tensor gLSE = local_tile(cute::domain_offset(make_coord(seqlen_info.offset), mLSE), Shape<Int<kBlockM>>{}, make_coord(m_block));
-        static_assert(kBlockM <= MaxThreadsPerBlock);
-        float lse = thread_idx < seqlen_o - m_block * kBlockM && thread_idx < kBlockM ? gLSE(thread_idx) : INFINITY;
-
-        GmemTiledCopy gmem_tiled_copy_O;
-        auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
-
-        Tensor tOgO = gmem_thr_copy_O.partition_S(gO);
-        Tensor tOgdO = gmem_thr_copy_O.partition_S(gdO);
-        // Construct identity layout for gO
-        Tensor cO = cute::make_identity_tensor(TileShape_MK{});  // (BLK_M,BLK_K) -> (blk_m,blk_k)
-        // Repeat the partitioning with identity layouts
-        Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
-        Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
-        #pragma unroll
-        for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); }
-
-        // (8, kBlockM / 32, kHeadDim / 64) or (8, kBlockM / 16, kHeadDim / 128)
-        Tensor tOrO = make_fragment_like(tOgO);
-        Tensor tOrdO = make_fragment_like(tOgdO);
-        flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true, /*Clear_OOB_K=*/true>(
-            gmem_tiled_copy_O, tOgO, tOrO, tOcO, tOpO, seqlen_o - m_block * kBlockM
-        );
-        flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true, /*Clear_OOB_K=*/true>(
-            gmem_tiled_copy_O, tOgdO, tOrdO, tOcO, tOpO, seqlen_o - m_block * kBlockM
-        );
-        // if (threadIdx.x == 222) { printf("bidx = %d, bidy = %d, bidz = %d, seqlen_o = %d, m_block = %d, seqlen_o - m_block * kBlockM = %d, tOgO addr = %p\n", blockIdx.x, blockIdx.y, blockIdx.z, seqlen_o, m_block, seqlen_o - m_block * kBlockM, &tOgO(0));}
-
-        // Reshape from e.g. (8, kBlockM / 32, kHeadDim / 64) to (kBlockM / 32, (8, kHeadDim / 64))
-        Layout l = make_layout(get<1>(tOrO.layout()), make_layout(get<0>(tOrO.layout()), get<2>(tOrO.layout())));
-        Tensor tOrO_l = make_tensor(tOrO.data(), l);
-        Tensor o_fp32 = make_tensor_like<float>(tOrO_l);
-        flash::convert_type_out(tOrO_l, o_fp32);
-        Tensor tOrdO_l = make_tensor(tOrdO.data(), l);
-        Tensor do_fp32 = make_tensor_like<float>(tOrdO_l);
-        flash::convert_type_out(tOrdO_l, do_fp32);
-        // Sum across the last dimension
-        Tensor dP_sum = make_tensor<float>(make_shape(size<0>(o_fp32)));
-        #pragma unroll
-        for (int mi = 0; mi < size<0>(o_fp32); ++mi) {
-            float dP_sum_cur = do_fp32(mi, 0) * o_fp32(mi, 0);
-            #pragma unroll
-            for (int ni = 1; ni < size<1>(o_fp32); ni++) {
-                dP_sum_cur += do_fp32(mi, ni) * o_fp32(mi, ni);
-            }
-            flash::SumOp<float> sum_op;
-            dP_sum(mi) = flash::Allreduce<kGmemThreadsPerRow>::run(dP_sum_cur, sum_op);
-        }
-
-        Tensor mdPsum = make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_dPsum, params.stride_dPsum)(_, bidh, !is_varlen ? bidb : 0);
-        Tensor gdPsum = local_tile(cute::domain_offset(make_coord(seqlen_info.offset_padded), mdPsum), Shape<Int<kBlockM>>{}, make_coord(m_block));
-        if (get<1>(tOcO(_0{}, _0{}, _0{})) == 0) {
-            #pragma unroll
-            for (int mi = 0; mi < size(dP_sum); ++mi) {
-                int const row = get<0>(tOcO(_0{}, mi, _0{}));
-                gdPsum(row) = row < seqlen_o - m_block * kBlockM ? dP_sum(mi) : 0;
-            }
-        }
-
-        int const seqlen_rounded = cute::round_up(seqlen_o, kBlockM);
-        Tensor mLSElog2 = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_dPsum, params.stride_LSE_log2)(_, bidh, !is_varlen ? bidb : 0);
-        Tensor gLSElog2 = local_tile(cute::domain_offset(make_coord(seqlen_info.offset_padded), mLSElog2), Shape<Int<kBlockM>>{}, make_coord(m_block));
-        if (thread_idx < seqlen_rounded - m_block * kBlockM && thread_idx < kBlockM) {
-            gLSElog2(thread_idx) = lse == -INFINITY ? 0.f : lse * float(M_LOG2E);
-        }
-
-        if constexpr (Clear_dQaccum) {
-            Tensor mdQaccum = make_tensor(make_gmem_ptr(params.ptr_dQaccum), params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0);
-            Tensor gdQaccum = local_tile(cute::domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdQaccum), Shape<Int<kBlockM * kHeadDim>>{}, make_coord(m_block));
-            GmemTiledCopyAccum gmem_tiled_copy_dQaccum;
-            auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(thread_idx);
-            Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);
-            Tensor zero = make_fragment_like(tdQgdQaccum);
-            clear(zero);
-            cute::copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, zero, tdQgdQaccum);
-        }
-
-        if (params.dq_semaphore != nullptr && thread_idx == 0) {
-            int const num_batch = params.num_batch;
-            int const num_head = get<2>(params.shape_O);
-            params.dq_semaphore[bidh + bidb * num_head + m_block * num_head * num_batch] = 0;
-        }
-
-    }
-
-};
-
-} // namespace flash
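
The preprocess kernel above computes, for each query row i, the quantity dPsum(i) = sum_k dO(i,k) * O(i,k) and rescales the log-sum-exp to base 2 (LSE_log2 = LSE * log2(e)), optionally zeroing dQaccum and the dq semaphore. A minimal sketch of that math (hypothetical names, one row per thread, none of the vectorized loads or warp reductions of the removed code):

```cuda
#include <math.h>

// Hypothetical reference for the preprocess math: o and do_ are fp32
// [seqlen_q, d] for one (head, batch) slice; lse is [seqlen_q].
__global__ void bwd_preprocess_ref(const float* __restrict__ o,
                                   const float* __restrict__ do_,
                                   const float* __restrict__ lse,
                                   float* __restrict__ dp_sum,
                                   float* __restrict__ lse_log2,
                                   int seqlen_q, int d) {
    int row = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= seqlen_q) return;
    float acc = 0.f;
    for (int k = 0; k < d; ++k) { acc += do_[row * d + k] * o[row * d + k]; }
    dp_sum[row] = acc;  // rowsum(dO * O), reused when forming dP in the bwd pass
    float l = lse[row];
    // 1.4426950408889634f = log2(e); lets the bwd kernels use exp2f instead of expf
    lse_log2[row] = (l == -INFINITY) ? 0.f : l * 1.4426950408889634f;
}
```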
flash-attn/flash_fwd_combine.cu
DELETED
@@ -1,13 +0,0 @@
-// Copyright (c) 2024, Tri Dao.
-// Splitting the different head dimensions to different files to speed up compilation.
-
-#include "flash_fwd_combine_launch_template.h"
-
-template void run_mha_fwd_combine_<float, float, 64>(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
-template void run_mha_fwd_combine_<float, float, 128>(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
-
-template void run_mha_fwd_combine_<cutlass::half_t, float, 64>(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
-template void run_mha_fwd_combine_<cutlass::half_t, float, 128>(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
-
-template void run_mha_fwd_combine_<cutlass::bfloat16_t, float, 64>(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
-template void run_mha_fwd_combine_<cutlass::bfloat16_t, float, 128>(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
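
A note on the pattern in this file: each heavy kernel template is explicitly instantiated for a fixed (output type, head dim) pair, and the instantiations are spread across translation units so nvcc can build them in parallel, exactly as the comment at the top of the file says. A generic sketch of explicit instantiation, with hypothetical names:

```cuda
// Hypothetical example of the explicit-instantiation pattern. The template
// body is visible in this TU; the "template void ..." lines force the
// compiler to emit those exact specializations for the linker.
template <typename T, int kHeadDim>
void run_combine_ref(const T* in, T* out, int n) {
    for (int i = 0; i < n; ++i) { out[i] = in[i]; }  // placeholder body
}

template void run_combine_ref<float, 64>(const float*, float*, int);
template void run_combine_ref<float, 128>(const float*, float*, int);
```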
flash-attn/flash_fwd_combine_kernel.h
DELETED
@@ -1,702 +0,0 @@
| 1 | 
            -
            /******************************************************************************
         | 
| 2 | 
            -
             * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
         | 
| 3 | 
            -
             ******************************************************************************/
         | 
| 4 | 
            -
             | 
| 5 | 
            -
            #pragma once
         | 
| 6 | 
            -
             | 
| 7 | 
            -
            #include "cute/tensor.hpp"
         | 
| 8 | 
            -
             | 
| 9 | 
            -
            #include <cutlass/cutlass.h>
         | 
| 10 | 
            -
            #include <cutlass/arch/memory.h>
         | 
| 11 | 
            -
            #include <cutlass/array.h>
         | 
| 12 | 
            -
            #include <cutlass/numeric_types.h>
         | 
| 13 | 
            -
            #include <cutlass/numeric_conversion.h>
         | 
| 14 | 
            -
             | 
| 15 | 
            -
            #include "cutlass/arch/grid_dependency_control.h"
         | 
| 16 | 
            -
             | 
| 17 | 
            -
            #include "seqlen.h"
         | 
| 18 | 
            -
            #include "utils.h"
         | 
| 19 | 
            -
             | 
| 20 | 
            -
            namespace flash {
         | 
| 21 | 
            -
             | 
| 22 | 
            -
            using namespace cute;
         | 
| 23 | 
            -
             | 
| 24 | 
            -
            template <class TileShape_MK_, int kLogMaxSplits_, int kNThreads, int AlignmentLSE_,
         | 
| 25 | 
            -
                      bool Is_even_K, bool Varlen, class Element, class ElementPartial, class ArchTag_>
         | 
| 26 | 
            -
            class FlashAttnFwdCombine {
         | 
| 27 | 
            -
             | 
| 28 | 
            -
            public:
         | 
| 29 | 
            -
             | 
| 30 | 
            -
                // Type Aliases
         | 
| 31 | 
            -
                using TileShape_MK = TileShape_MK_;
         | 
| 32 | 
            -
                using ArchTag = ArchTag_;
         | 
| 33 | 
            -
                static constexpr int kMaxSplits = 1 << kLogMaxSplits_;
         | 
| 34 | 
            -
                static constexpr int AlignmentLSE = std::min(AlignmentLSE_, int(128 / 8 / sizeof(float)));
         | 
| 35 | 
            -
                static_assert(AlignmentLSE >= 1);
         | 
| 36 | 
            -
                static constexpr int kStages = 4;
         | 
| 37 | 
            -
             | 
| 38 | 
            -
                static_assert(ArchTag::kMinComputeCapability >= 75);
         | 
| 39 | 
            -
                static constexpr bool Has_cp_async = ArchTag::kMinComputeCapability >= 80;
         | 
| 40 | 
            -
             | 
| 41 | 
            -
                static constexpr uint32_t MaxThreadsPerBlock = kNThreads;
         | 
| 42 | 
            -
                static constexpr uint32_t MinBlocksPerMultiprocessor = 2;
         | 
| 43 | 
            -
             | 
| 44 | 
            -
                static constexpr int kBlockM = get<0>(TileShape_MK{});
         | 
| 45 | 
            -
                static constexpr int kBlockK = get<1>(TileShape_MK{});
         | 
| 46 | 
            -
             | 
| 47 | 
            -
                static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(ElementPartial);
         | 
| 48 | 
            -
                static_assert(kBlockK % kGmemElemsPerLoad == 0, "kBlockK must be a multiple of kGmemElemsPerLoad");
         | 
| 49 | 
            -
                static constexpr int kBlockKGmem = kBlockK % 128 == 0 ? 128 : (kBlockK % 64 == 0 ? 64 : 32);
         | 
| 50 | 
            -
                static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad;
         | 
| 51 | 
            -
                static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow");
         | 
| 52 | 
            -
                using GmemCopyAtom = std::conditional_t<
         | 
| 53 | 
            -
                    Has_cp_async,
         | 
| 54 | 
            -
                    cute::Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<uint128_t>, ElementPartial>,
         | 
| 55 | 
            -
                    cute::Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementPartial>
         | 
| 56 | 
            -
                >;
         | 
| 57 | 
            -
                using GmemLayoutAtom = Layout<Shape <Int<MaxThreadsPerBlock / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
         | 
| 58 | 
            -
                                              Stride<Int<kGmemThreadsPerRow>, _1>>;
         | 
| 59 | 
            -
                static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0);
         | 
| 60 | 
            -
                using GmemTiledCopyAccum = decltype(
         | 
| 61 | 
            -
                    make_tiled_copy(GmemCopyAtom{},
         | 
| 62 | 
            -
                                    GmemLayoutAtom{},
         | 
| 63 | 
            -
                                    Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{}));  // Val layout, 4 vals per load
         | 
| 64 | 
            -
                using GmemTiledCopy = decltype(
         | 
| 65 | 
            -
                    make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
         | 
| 66 | 
            -
                                    GmemLayoutAtom{},
         | 
| 67 | 
            -
                                    Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{}));  // Val layout, 4 vals per load
         | 
| 68 | 
            -
             | 
| 69 | 
            -
                using AlignmentTypeLSE = cute::uint_byte_t<static_cast<int>(sizeof(float)) * AlignmentLSE>;
         | 
| 70 | 
            -
                static constexpr int kGmemElemsPerLoadLSE = sizeof(AlignmentTypeLSE) / sizeof(float);
         | 
| 71 | 
            -
                static_assert(kBlockM % kGmemElemsPerLoadLSE == 0, "kBlockM must be a multiple of kGmemElemsPerLoadLSE");
         | 
| 72 | 
            -
                static_assert(kBlockM % 8 == 0, "kBlockM must be a multiple of 8");
         | 
| 73 | 
            -
                static constexpr int kBlockMSmem = kBlockM % 128 == 0 ? 128 : (kBlockM % 64 == 0 ? 64 : (kBlockM % 32 == 0 ? 32 : (kBlockM % 16 == 0 ? 16 : 8)));
         | 
| 74 | 
            -
                static constexpr int kGmemThreadsPerRowLSE = kBlockMSmem / kGmemElemsPerLoadLSE;
         | 
| 75 | 
            -
                static_assert(MaxThreadsPerBlock % kGmemThreadsPerRowLSE == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRowLSE");
         | 
| 76 | 
            -
                using GmemLayoutAtomLSE = Layout<Shape <Int<MaxThreadsPerBlock / kGmemThreadsPerRowLSE>, Int<kGmemThreadsPerRowLSE>>,
         | 
| 77 | 
            -
                                                 Stride<Int<kGmemThreadsPerRowLSE>, _1>>;
         | 
| 78 | 
            -
                static_assert(kMaxSplits % CUTE_STATIC_V(shape<0>(GmemLayoutAtomLSE{})) == 0);
         | 
| 79 | 
            -
                using GmemCopyAtomLSE = std::conditional_t<
         | 
| 80 | 
            -
                    Has_cp_async,
         | 
| 81 | 
            -
                    cute::Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<AlignmentTypeLSE>, float>,
         | 
| 82 | 
            -
                    cute::Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<AlignmentLSE * sizeof(float) * 8>, float>
         | 
| 83 | 
            -
                >;
         | 
| 84 | 
            -
                using GmemTiledCopyLSE = decltype(
         | 
| 85 | 
            -
                    make_tiled_copy(GmemCopyAtomLSE{},
         | 
| 86 | 
            -
                                    GmemLayoutAtomLSE{},
         | 
| 87 | 
            -
                                    Layout<Shape<_1, Int<kGmemElemsPerLoadLSE>>>{}));  // Val layout, 4 vals per load
         | 
| 88 | 
            -
             | 
| 89 | 
            -
                // Otherwise we get IMA when some threads access sLSE, as we're not doing any masking
         | 
| 90 | 
            -
                static_assert((kBlockM * kMaxSplits * AlignmentLSE) % kNThreads == 0, "kNThreads must divide kBlockM * kMaxSplits * AlignmentLSE");
         | 
| 91 | 
            -
                // This works for kBlockMSmem = 8, 16, 32, 64, 128, no bank conflicts
         | 
| 92 | 
            -
                using SmemLSESwizzle = std::conditional_t<
         | 
| 93 | 
            -
                    kBlockMSmem == 8,
         | 
| 94 | 
            -
                    Swizzle<5, 0, 5>,
         | 
| 95 | 
            -
                    std::conditional_t<kBlockMSmem == 16, Swizzle<4, 0, 4>, Swizzle<3, 2, 3>>
         | 
| 96 | 
            -
                >;
         | 
| 97 | 
            -
                using SmemLayoutAtomLSE =
         | 
| 98 | 
            -
                    decltype(composition(SmemLSESwizzle{},
         | 
| 99 | 
            -
                             Layout<Shape<Int<8>, Int<kBlockMSmem>>,
         | 
| 100 | 
            -
                             Stride<Int<kBlockMSmem>, _1>>{}));
         | 
| 101 | 
            -
                using SmemLayoutLSE = decltype(tile_to_shape(SmemLayoutAtomLSE{}, Shape<Int<kMaxSplits>, Int<kBlockM>>{}));
         | 
| 102 | 
            -
             | 
| 103 | 
            -
                using SmemLayoutO = Layout<Shape<Int<kBlockM>, Int<kBlockK>, Int<kStages>>,
         | 
| 104 | 
            -
                                           Stride<Int<kBlockK>, _1, Int<kBlockM * kBlockK>>>;
         | 
| 105 | 
            -
             | 
| 106 | 
            -
                // We want each column (kMaxSplits) to be processed by threads in the same warp.
         | 
| 107 | 
            -
                // To reduce the number of shuffles, we want as few threads on the same column as possible.
         | 
| 108 | 
            -
                // E.g., if kBlockM is divisible by 64, and there are 256 threads, we want 4 threads (0, 1, 2, 4) per column
         | 
| 109 | 
            -
                // have have 64 such quads.
         | 
| 110 | 
            -
                static_assert(MaxThreadsPerBlock % kBlockMSmem == 0, "MaxThreadsPerBlock must be a multiple of kBlockMSmem");
         | 
| 111 | 
            -
                static constexpr int kSmemThreadsPerColLSEt = MaxThreadsPerBlock / kBlockMSmem;
         | 
| 112 | 
            -
                static_assert(cutlass::NumThreadsPerWarp % kSmemThreadsPerColLSEt == 0, "kSmemThreadsPerColLSEt must divide NumThreadsPerWarp");
         | 
| 113 | 
            -
                using S2RLayoutAtomLSE = Layout<Shape<Int<kSmemThreadsPerColLSEt>, Int<MaxThreadsPerBlock / kSmemThreadsPerColLSEt>>>;
         | 
| 114 | 
            -
                using S2RTiledCopyLSE = decltype(make_tiled_copy(cute::Copy_Atom<cute::DefaultCopy, float>{}, S2RLayoutAtomLSE{}, Layout<_1>{}));
         | 
| 115 | 
            -
             | 
| 116 | 
            -
                using ShapeOPartial = cute::Shape<int32_t, int32_t, int32_t, int32_t, int32_t>;  // (seqlen, d, num_splits, head, batch)
         | 
| 117 | 
            -
                using StrideOPartial = cute::Stride<int64_t, _1, int64_t, int64_t, int64_t>;
         | 
| 118 | 
            -
                using ShapeLSEPartial = cute::Shape<int32_t, int32_t, int32_t, int32_t>;  // (seqlen, num_splits, head, batch)
         | 
| 119 | 
            -
                using StrideLSEPartial = cute::Stride<_1, int64_t, int64_t, int64_t>;  // (seqlen, num_splits, head, batch)
         | 
| 120 | 
            -
                using ShapeO = cute::Shape<int32_t, int32_t, int32_t, int32_t>;  // (seqlen, d, head, batch)
         | 
| 121 | 
            -
                using StrideO = cute::Stride<int64_t, _1, int64_t, int64_t>;
         | 
| 122 | 
            -
                using ShapeLSE = cute::Shape<int32_t, int32_t, int32_t>;  // (seqlen, head, batch)
         | 
| 123 | 
            -
                using StrideLSE = cute::Stride<_1, int64_t, int64_t>;  // (seqlen, head, batch)
         | 
| 124 | 
            -
             | 

    struct BlockCoord {
        int block_m;
        int block_k;
        int bidb;
    };

    struct SharedStorage : cute::aligned_struct<128> {
        cute::array_aligned<float, cute::cosize_v<SmemLayoutLSE>> smem_lse_partial;
        cute::array_aligned<int, kBlockM> smem_max_valid_split;
        cute::array_aligned<ElementPartial, cute::cosize_v<SmemLayoutO>> smem_o_partial;
        BlockCoord block_coord;
    };

    static constexpr int SharedStorageSize = sizeof(SharedStorage);

    // Device side arguments
    struct Arguments {
        int b;
        ElementPartial const* const ptr_O_partial;
        ShapeOPartial const shape_O_partial;
        StrideOPartial const stride_O_partial;
        float const* const ptr_LSE_partial;
        ShapeLSEPartial const shape_LSE_partial;
        StrideLSEPartial const stride_LSE_partial;
        Element* const ptr_O;
        StrideO const stride_O;
        float* const ptr_LSE;
        StrideLSE const stride_LSE;
        int const* const cu_seqlens = nullptr;
        int const* const seqused = nullptr;
        int const* const num_splits_dynamic_ptr = nullptr;
        int* const semaphore_to_reset = nullptr;
    };

    // Kernel entry point API
    struct CollectiveParams {
        int b;
        ElementPartial const* const ptr_O_partial;
        ShapeOPartial const shape_O_partial;
        StrideOPartial const stride_O_partial;
        float const* const ptr_LSE_partial;
        ShapeLSEPartial const shape_LSE_partial;
        StrideLSEPartial const stride_LSE_partial;
        Element* const ptr_O;
        StrideO const stride_O;
        float* const ptr_LSE;
        StrideLSE const stride_LSE;
        cutlass::FastDivmod seqlen_divmod, head_divmod;
        int const* const cu_seqlens = nullptr;
        int const* const seqused = nullptr;
        int const* const num_splits_dynamic_ptr = nullptr;
        int* const semaphore_to_reset = nullptr;
    };

    // Convert to underlying arguments. In this case, a simple copy for the aliased type.
    static
    CollectiveParams
    to_underlying_arguments(Arguments const& args) {
        assert(get<1>(args.shape_LSE_partial) <= kMaxSplits);
        return {
            args.b,
            args.ptr_O_partial,
            args.shape_O_partial,
            args.stride_O_partial,
            args.ptr_LSE_partial,
            args.shape_LSE_partial,
            args.stride_LSE_partial,
            args.ptr_O,
            args.stride_O,
            args.ptr_LSE,
            args.stride_LSE,
            cutlass::FastDivmod(get<0>(args.shape_LSE_partial)), cutlass::FastDivmod(get<2>(args.shape_LSE_partial)),
            args.cu_seqlens,
            args.seqused,
            args.num_splits_dynamic_ptr,
            args.semaphore_to_reset
        };
    }
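
    // Editorial note (not in the original source): cutlass::FastDivmod
    // precomputes magic-multiply constants so that a runtime divide/modulo by
    // seqlen (or num_heads) becomes multiplies and shifts on device. The
    // kernel uses it to split a flattened row index idx = bidh * seqlen + m_idx
    // back into (bidh, m_idx) without a hardware division in the hot loops.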

    struct SchedulerArguments {
        int b;
        int seqlen_q;
        int total_q;
        int num_heads;
        int dv;
        int const* cu_seqlens_q;
        int const* seqused_q;
    };

    struct StaticTileScheduler {
        struct Params {};
        static Params to_underlying_arguments(SchedulerArguments const& args) { return {}; }

        SharedStorage& shared_storage;
        CUTE_DEVICE StaticTileScheduler(SharedStorage& shared_storage): shared_storage(shared_storage) {}

        static dim3 get_grid_shape(SchedulerArguments const& args) {
            unsigned int num_blocks_k = cute::ceil_div(args.dv, kBlockK);
            unsigned int num_blocks_m = cute::ceil_div(args.seqlen_q * args.num_heads, kBlockM);
            return {num_blocks_m, num_blocks_k, static_cast<unsigned int>(args.b)};
        }

        CUTE_DEVICE BlockCoord get_block_coord(Params const& params) {
            int block_m = blockIdx.x;
            int block_k = blockIdx.y;
            int bidb = blockIdx.z;
            return {block_m, block_k, bidb};
        }
    };

    struct StaticVarlenTileScheduler {
        //
        // For varlen we have two scheduling algos:
        //  1) STANDARD, same as StaticTileScheduler
        //  2) LINEARIZE_M_AND_BATCH, which flattens the tiled M dimension and
        //     the batch dimension into a linear tile index. The grid is then a
        //     2D grid of (tile_id, k_block). We then map the linear tile id
        //     to (m_block, bidb) in the get_block_coord function. This mapping
        //     is non-trivial since each batch element can have a different
        //     number of m_blocks. This has overhead when computing the block
        //     coordinates, but it is more efficient when prefills and decodes
        //     are mixed, since in that case the STANDARD scheduling algo will
        //     have a lot of empty (no work) blocks in the grid.
        //

        enum SchedulingAlgo {
            STANDARD,               // Same as StaticTileScheduler
            LINEARIZE_M_AND_BATCH,  // Linearize the M and batch dimensions into a single tile index
        };

        struct Params {
            int b;
            int num_heads;
            int const* const cu_seqlens_q;
            int const* const seqused_q;
            SchedulingAlgo algo;
        };

        SharedStorage& shared_storage;
        CUTE_DEVICE StaticVarlenTileScheduler(SharedStorage& shared_storage): shared_storage(shared_storage) {}

        static SchedulingAlgo choose_scheduling_algo(SchedulerArguments const& args) {
            // Choose the scheduling algorithm based on how dense the grid of tiles that
            // do actual work is. If the grid is more than 50% sparse, we linearize the M
            // and batch dimensions. If the grid is more than 50% dense, we use the standard
            // scheduling algorithm, since it's more efficient at calculating the block coordinates.
            // NOTE: in the varlen case args.seqlen_q is the max seqlen_q across all batches;
            // use a lower bound to estimate when the density is more than 50%.
            int lower_bound_on_non_empty_tiles = cute::ceil_div(args.total_q, kBlockM);
            int grid_size = args.b * cute::ceil_div(args.seqlen_q, kBlockM);
            return 2 * lower_bound_on_non_empty_tiles >= grid_size ?
                SchedulingAlgo::STANDARD :
                SchedulingAlgo::LINEARIZE_M_AND_BATCH;
        }

        static Params to_underlying_arguments(SchedulerArguments const& args) {
            return {
                args.b,
                args.num_heads,
                args.cu_seqlens_q,
                args.seqused_q,
                choose_scheduling_algo(args)
            };
        }

        static dim3 get_grid_shape(SchedulerArguments const& args) {
            unsigned int num_blocks_k = cute::ceil_div(args.dv, kBlockK);

            switch (choose_scheduling_algo(args)) {
            case SchedulingAlgo::STANDARD: {
                unsigned int num_blocks_m = cute::ceil_div(args.seqlen_q * args.num_heads, kBlockM);
                return {num_blocks_m, num_blocks_k, static_cast<unsigned int>(args.b)};
            }
            case SchedulingAlgo::LINEARIZE_M_AND_BATCH: {
                // rough worst case upper bound on the number of blocks required
                //  (assuming each batch has an additional partial block)
                unsigned int num_blocks_m = cute::ceil_div(args.total_q * args.num_heads, kBlockM) + args.b;
                return {num_blocks_m, num_blocks_k, 1};
            }}

            // Unreachable: both switch cases return. Kept (duplicating the
            // LINEARIZE_M_AND_BATCH result) to satisfy compilers that cannot
            // prove the switch is exhaustive.
            unsigned int num_blocks_m = cute::ceil_div(args.total_q * args.num_heads, kBlockM) + args.b;
            return {num_blocks_m, num_blocks_k, 1};
        }

        CUTE_DEVICE BlockCoord get_block_coord_linearized_m_and_batch(Params const& params) {
            int num_heads = params.num_heads;
            int curr_tile_id = blockIdx.x;

            // Scan through the batches to find the batch that contains the current
            // tile_id. Compute using only the first warp of the block.
            if (threadIdx.x < 32) {
                // We compute linearized tile index starts and ends for each batch
                // in groups of 32 in parallel
                int group_start_bidb = -(cutlass::NumThreadsPerWarp);
                int group_end_bidb = 0;
                int group_end_tile_id = 0;
                int group_start_tile_id = 0;
                int group_total_num_tiles = 0;

                int local_num_m_blocks = 0;
                int local_num_m_blocks_cumulative = 0;

                do {
                    group_start_bidb += cutlass::NumThreadsPerWarp;
                    group_end_bidb += cutlass::NumThreadsPerWarp;

                    auto get_num_m_blocks = [&](int bidb) {
                        if (bidb >= params.b) return 0;
                        flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{bidb, 0, params.cu_seqlens_q, params.seqused_q};
                        return cute::ceil_div(seqlen_info.seqlen * num_heads, Int<kBlockM>{}());
                    };

                    // Cumulative number of blocks for the next 31 batches
                    local_num_m_blocks = get_num_m_blocks(group_start_bidb + threadIdx.x);
                    local_num_m_blocks_cumulative = warp_prefix_sum(local_num_m_blocks);
                    // Total number of blocks for the next 32 batches
                    group_total_num_tiles = warp_shfl_get_last(local_num_m_blocks_cumulative);

                    group_start_tile_id = group_end_tile_id;
                    group_end_tile_id += group_total_num_tiles;
                } while (curr_tile_id >= group_end_tile_id && group_end_bidb < params.b);

                int local_batch_end_tile_id = group_start_tile_id + local_num_m_blocks_cumulative;
                // Find the last batch idx in the group where `local_batch_end_tile_id <= curr_tile_id`;
                // the values below are now common to all threads in the warp
                int batch_idx_in_group = warp_last_true_laneid(local_batch_end_tile_id <= curr_tile_id);
                int batch_num_m_blocks = warp_shfl_get(local_num_m_blocks, batch_idx_in_group);
                int batch_m_start_tile_id = group_start_tile_id + (batch_idx_in_group > 0 ?
                    warp_shfl_get(local_num_m_blocks_cumulative, batch_idx_in_group - 1) : 0);

                int bidb = group_start_bidb + batch_idx_in_group;
                int block_m = curr_tile_id - batch_m_start_tile_id;
                // NOTE(lucas): not sure why this causes a block_k unused warning;
                //  just inlined `blockIdx.y` to suppress the warning
                // int block_k = blockIdx.y;
                // shared_storage.block_coord = {block_m, block_k, bidb};
                BlockCoord block_coord{block_m, static_cast<int>(blockIdx.y), bidb};
                if (threadIdx.x == 0) { shared_storage.block_coord = block_coord; }
            }

            __syncthreads();
            return shared_storage.block_coord;
        }


        CUTE_DEVICE BlockCoord get_block_coord_standard(Params const& params) {
            int block_m = blockIdx.x;
            int block_k = blockIdx.y;
            int bidb = blockIdx.z;
            return {block_m, block_k, bidb};
        }

        CUTE_DEVICE BlockCoord get_block_coord(Params const& params) {
            switch (params.algo) {
                case SchedulingAlgo::STANDARD:
                    return get_block_coord_standard(params);
                case SchedulingAlgo::LINEARIZE_M_AND_BATCH:
                    return get_block_coord_linearized_m_and_batch(params);
            }
            return {0, 0, 0};  // Should never reach here
        }
    };
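
    // Editorial note (not in the original source): in serial form, the
    // linearized mapping above is equivalent to
    //     int tile = curr_tile_id, bidb = 0;
    //     while (tile >= num_m_blocks(bidb)) { tile -= num_m_blocks(bidb); ++bidb; }
    //     // block_m = tile, batch = bidb
    // where num_m_blocks(b) = ceil_div(seqlen_q(b) * num_heads, kBlockM); the
    // warp-parallel scan processes 32 batches per iteration instead of one.
    // Worked example for the 50% density heuristic (hypothetical numbers):
    // b = 8, max seqlen_q = 1024, total_q = 1200, kBlockM = 64 gives a dense
    // grid of 8 * 16 = 128 m-tiles per head, of which at least
    // ceil(1200 / 64) = 19 are non-empty; since 2 * 19 < 128, the scheduler
    // picks LINEARIZE_M_AND_BATCH.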

    using TileScheduler = std::conditional_t<
        Varlen,
        StaticVarlenTileScheduler,
        StaticTileScheduler
    >;

    using SchedulerParams = typename TileScheduler::Params;

    struct Params {
        CollectiveParams params;
        SchedulerParams scheduler_params;
    };

    CUTLASS_DEVICE
    void
    operator()(Params const& kernel_params, char* smem_buf) {
        CollectiveParams const& params = kernel_params.params;

        SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
        TileScheduler tile_scheduler{shared_storage};

        Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.smem_lse_partial.data()), SmemLayoutLSE{});
        Tensor sMaxValidSplit = make_tensor(make_smem_ptr(shared_storage.smem_max_valid_split.data()), Shape<Int<kBlockM>>{});
        Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o_partial.data()), SmemLayoutO{});

        int const thread_idx = threadIdx.x;

        BlockCoord block_coord = tile_scheduler.get_block_coord(kernel_params.scheduler_params);

        int const m_block = block_coord.block_m;
        int const k_block = block_coord.block_k;
        int const batch = block_coord.bidb;

        if (params.semaphore_to_reset && threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1 && blockIdx.z == gridDim.z - 1) {
            cutlass::arch::wait_on_dependent_grids();
            *params.semaphore_to_reset = 0;
        }

        flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{batch, size<0>(params.shape_LSE_partial), params.cu_seqlens, params.seqused};
        int const offset = seqlen_info.offset;
        int const seqlen = seqlen_info.seqlen;
        int max_idx = seqlen * get<2>(params.shape_LSE_partial);

        bool block_coord_valid =
            block_coord.block_m < cute::ceil_div(max_idx, Int<kBlockM>{}) &&
            block_coord.bidb < params.b;
        if (!block_coord_valid) { return; }

        int const num_splits = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[batch] : get<1>(params.shape_LSE_partial);
        if (num_splits <= 1) { return; }

        cutlass::FastDivmod seqlen_divmod_dynamic(seqlen);

        // Step 1: load LSE_partial from gmem -> smem
        Tensor mLSEpartial = make_tensor(make_gmem_ptr(params.ptr_LSE_partial + offset * get<0>(params.stride_LSE_partial)),
                                         select<1, 0, 2, 3>(params.shape_LSE_partial),
                                         select<1, 0, 2, 3>(params.stride_LSE_partial))(_, _, _, !Varlen ? batch : 0);  // (num_splits, seqlen, head)
        Tensor mLSEpartial_copy = cute::tiled_divide(mLSEpartial, Shape<_1, Int<kGmemElemsPerLoadLSE>>{});
        GmemTiledCopyLSE gmem_tiled_copy_LSE;
        auto gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_thread_slice(thread_idx);
        Tensor tLSEsLSE = gmem_thr_copy_LSE.partition_D(sLSE);

        // Construct identity layout for sLSE
        Tensor cLSE = make_identity_tensor(make_shape(size<0>(sLSE), size<1>(sLSE)));    // (NUM_SPLITS, BLK_M) -> (num_splits, blk_m)
        // Repeat the partitioning with identity layouts
        Tensor tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE);

        cutlass::arch::wait_on_dependent_grids();

        #pragma unroll
        for (int m = 0; m < size<2>(tLSEcLSE); ++m) {
            int mi = int(get<1>(tLSEcLSE(_0{}, _0{}, m)));
            int idx = m_block * kBlockM + mi;
            if (idx < max_idx) {
                int m_idx, bidh;
                if constexpr (!Varlen) {
                    bidh = params.seqlen_divmod.divmod(m_idx, idx);
                } else {
                    bidh = seqlen_divmod_dynamic.divmod(m_idx, idx);
                }
                Tensor mLSEpartial_cur_copy = mLSEpartial_copy(_, _, m_idx, bidh);
                #pragma unroll
                for (int s = 0; s < size<1>(tLSEcLSE); ++s) {
                    int si = get<0>(tLSEcLSE(_0{}, s, _0{}));
                    if (si < num_splits) {
                        cute::copy(gmem_tiled_copy_LSE, mLSEpartial_cur_copy(_, si), tLSEsLSE(_, s, m));
                    } else {
                        cute::fill(tLSEsLSE(_, s, m), -INFINITY);
                    }
                }
            } else {
                // We don't need to zero out the rest of the LSEs, as we will not write the output to gmem
                // cute::fill(tLSEsLSE(_, _, m), -INFINITY);
            }
        }
        if constexpr (Has_cp_async) { cute::cp_async_fence(); }

        // Step 2: Load O_partial from gmem -> smem for split = 0, 1, ..., kStages - 2.
        // We want these async loads to be in flight as we compute the LSE.
        GmemTiledCopyAccum gmem_tiled_copy_O_partial;
        auto gmem_thr_copy_O_partial = gmem_tiled_copy_O_partial.get_thread_slice(thread_idx);
        // Construct identity layout for gO
        Tensor cO = cute::make_identity_tensor(TileShape_MK{});  // (BLK_M,BLK_K) -> (blk_m,blk_k)
        // Repeat the partitioning with identity layouts
        Tensor tOcO = gmem_thr_copy_O_partial.partition_D(cO);
        Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset * get<0>(params.stride_O_partial)),
                                       params.shape_O_partial, params.stride_O_partial)(_, _, _, _, !Varlen ? batch : 0);  // (seqlen, d, num_splits, head)

        // Precompute these values to avoid recomputing them in the loop
        Tensor tOmidx = make_tensor<int>(make_shape(size<1>(tOcO)));
        Tensor tObidh = make_tensor<int>(make_shape(size<1>(tOcO)));
        Tensor tOrOptr = make_tensor<ElementPartial const*>(make_shape(size<1>(tOcO)));
        #pragma unroll
        for (int m = 0; m < size<1>(tOcO); ++m) {
            int mi = get<0>(tOcO(_0{}, m, _0{}));
            int idx = m_block * kBlockM + mi;
            if constexpr (!Varlen) {
                tObidh(m) = params.seqlen_divmod.divmod(tOmidx(m), idx);
            } else {
                tObidh[m] = seqlen_divmod_dynamic.divmod(tOmidx(m), idx);
            }
            tOrOptr[m] = &mOpartial(tOmidx(m), k_block * kBlockK, _0{}, tObidh(m));
            if (idx >= max_idx) {
                tObidh[m] = -1;
            }
        }

        Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOcO)));
        if constexpr (!(Is_even_K)) {
            #pragma unroll
            for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O_partial) - k_block * kBlockK; }
        }

        Tensor tOsOpartial = gmem_thr_copy_O_partial.partition_D(sO);

        auto load_O_partial = [&] (int split, int stage) {
            Tensor tOsOpartial_cur = tOsOpartial(_, _, _, stage);
            #pragma unroll
            for (int m = 0; m < size<1>(tOcO); ++m) {
                if (tObidh(m) >= 0)  {
                    Tensor mOpartial_cur = make_tensor(make_gmem_ptr(tOrOptr[m]), mOpartial(_0{}, _, _, _0{}).layout());
                    Tensor mOpartial_cur_copy = cute::tiled_divide(mOpartial_cur, Shape<Int<kGmemElemsPerLoad>>{});
                    #pragma unroll
                    for (int k = 0; k < size<2>(tOcO); ++k) {
                        int k_idx = get<1>(tOcO(_0{}, _0{}, k)) / kGmemElemsPerLoad;
                        if (Is_even_K || tOpO(k)) {
                            cute::copy(gmem_tiled_copy_O_partial, mOpartial_cur_copy(_, k_idx, split), tOsOpartial_cur(_, m, k));
                        }
                    }
                }
            }
        };

        for (int s = 0; s < kStages - 1; ++s) {
            if (s < num_splits) { load_O_partial(s, s); }
            if constexpr (Has_cp_async) { cute::cp_async_fence(); }
        }

        // Step 3: load and transpose LSE_partial from smem -> rmem
        if constexpr (Has_cp_async) { cutlass::arch::cp_async_wait<kStages - 1>(); }
        __syncthreads();

        S2RTiledCopyLSE s2r_tiled_copy_LSE;
        auto s2r_thr_copy_LSE = s2r_tiled_copy_LSE.get_thread_slice(thread_idx);
        Tensor ts2rsLSE = s2r_thr_copy_LSE.partition_S(sLSE);
        Tensor ts2rrLSE = make_fragment_like(ts2rsLSE);
        cute::copy(s2r_tiled_copy_LSE, ts2rsLSE, ts2rrLSE);

        // Step 4: compute the final LSE along the split dimension
        Tensor lse_sum = make_tensor<float>(make_shape(size<2>(ts2rrLSE)));
        Tensor ts2rcLSE = s2r_thr_copy_LSE.partition_D(cLSE);
        // We compute the max valid split for each row to short-circuit the computation later
        Tensor max_valid_split = make_tensor<int>(make_shape(size<2>(ts2rrLSE)));
        static_assert(CUTE_STATIC_V(size<0>(ts2rrLSE)) == 1);
        #pragma unroll
        for (int m = 0; m < size<2>(ts2rrLSE); ++m) {
            float lse_max = ts2rrLSE(_0{}, _0{}, m);
            #pragma unroll
            for (int s = 1; s < size<1>(ts2rrLSE); ++s) { lse_max = max(lse_max, ts2rrLSE(_0{}, s, m)); }
            MaxOp<float> max_op;
            lse_max = Allreduce<kSmemThreadsPerColLSEt>::run(lse_max, max_op);
            int max_valid_idx = -1;
            #pragma unroll
            for (int s = 0; s < size<1>(ts2rrLSE); ++s) {
                if (ts2rrLSE(_0{}, s, m) != -INFINITY) { max_valid_idx = get<0>(ts2rcLSE(_0{}, s, _0{})); }
            }
            MaxOp<int> max_int_op;
            max_valid_split[m] = Allreduce<kSmemThreadsPerColLSEt>::run(max_valid_idx, max_int_op);
            float lse_max_cur = lse_max == -INFINITY ? 0.0f : lse_max;  // In case all local LSEs are -inf
            float lse_sum_cur = 0.f;
            #pragma unroll
            for (int s = 0; s < size<1>(ts2rrLSE); ++s) {
                float scale = expf(ts2rrLSE(_0{}, s, m) - lse_max_cur);
                lse_sum_cur += scale;
                ts2rrLSE(_0{}, s, m) = scale;
            }
            SumOp<float> sum_op;
            lse_sum_cur = Allreduce<kSmemThreadsPerColLSEt>::run(lse_sum_cur, sum_op);
            lse_sum(m) = logf(lse_sum_cur) + lse_max;
            float inv_sum = (lse_sum_cur == 0.f || lse_sum_cur != lse_sum_cur) ? 0.f : 1.f / lse_sum_cur;
            #pragma unroll
            for (int s = 0; s < size<1>(ts2rrLSE); ++s) { ts2rrLSE(_0{}, s, m) *= inv_sum; }
        }
        // Store the scales exp(lse - lse_logsum) back to smem
        cute::copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE);

        // Store max_valid_split to smem
        #pragma unroll
        for (int m = 0; m < size<2>(ts2rrLSE); ++m) {
            if (get<0>(ts2rcLSE(_0{}, _0{}, m)) == 0) {  // Only the thread responsible for s=0 writes to smem
                int mi = int(get<1>(ts2rcLSE(_0{}, _0{}, m)));
                if (mi < kBlockM) { sMaxValidSplit[mi] = max_valid_split[m]; }
            }
        }

        // Step 5: store final LSE back to gmem
        if (k_block == 0) {
            auto shape_LSE = select<0, 2, 3>(params.shape_LSE_partial);
            Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset * get<0>(params.stride_LSE)), shape_LSE, params.stride_LSE)(_, _, !Varlen ? batch : 0);
            #pragma unroll
            for (int m = 0; m < size<2>(ts2rrLSE); ++m) {
                if (get<0>(ts2rcLSE(_0{}, _0{}, m)) == 0) {  // Only the thread responsible for s=0 writes to gmem
                    int mi = int(get<1>(ts2rcLSE(_0{}, _0{}, m)));
                    int idx = m_block * kBlockM + mi;
                    if (idx < max_idx) {
                        int m_idx, bidh;
                        if constexpr (!Varlen) {
                            bidh = params.seqlen_divmod.divmod(m_idx, idx);
                        } else {
                            bidh = seqlen_divmod_dynamic.divmod(m_idx, idx);
                        }
                        mLSE(m_idx, bidh) = lse_sum(m);
                    }
                }
            }
        }
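
        // Editorial note (not in the original source): Steps 4-5 compute, per
        // output row, the combined statistic
        //     lse = log(sum_s exp(lse_s)) = lse_max + log(sum_s exp(lse_s - lse_max)),
        // using lse_max for numerical stability. The values written back into
        // sLSE above are the per-split weights
        //     scale_s = exp(lse_s - lse) = exp(lse_s - lse_max) / sum_s' exp(lse_s' - lse_max),
        // which Step 6 uses to rescale and accumulate the partial outputs.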

        // Step 6: read O_partial from gmem -> smem -> rmem and accumulate the final O
        __syncthreads();
        int thr_max_valid_split = sMaxValidSplit[get<0>(tOcO(_0{}, _0{}, _0{}))];
        #pragma unroll
        for (int m = 1; m < size<1>(tOcO); ++m) { thr_max_valid_split = max(thr_max_valid_split, sMaxValidSplit[get<0>(tOcO(_0{}, m, _0{}))]); }
        Layout tOrOpartial_layout = gmem_thr_copy_O_partial.partition_S(make_tensor<ElementPartial>(TileShape_MK{})).layout();
        Tensor tOrOpartial = make_fragment_like<ElementPartial>(tOrOpartial_layout);
        Tensor tOrO = make_fragment_like<float>(tOrOpartial);
        clear(tOrO);
        int stage_load = kStages - 1, stage_compute = 0;
        #pragma unroll 4 // Already tuned for speed
        for (int s = 0; s <= thr_max_valid_split; ++s) {
            Tensor scale = make_tensor<float>(make_shape(size<1>(tOrOpartial)));
            #pragma unroll
            for (int m = 0; m < size<1>(tOrOpartial); ++m) { scale(m) = sLSE(s, get<0>(tOcO(_0{}, m, _0{}))); }

            if (s + kStages - 1 <= thr_max_valid_split) { load_O_partial(s + kStages - 1, stage_load); }
            if constexpr (Has_cp_async) { cute::cp_async_fence(); }
            stage_load = stage_load < kStages - 1 ? stage_load + 1 : 0;
            if constexpr (Has_cp_async) { cutlass::arch::cp_async_wait<kStages - 1>(); }
            // We don't need __syncthreads() because each thread is just reading its own data from smem
            cute::copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementPartial>{},
                       tOsOpartial(_, _, _, stage_compute), tOrOpartial);
            stage_compute = stage_compute < kStages - 1 ? stage_compute + 1 : 0;

            #pragma unroll
            for (int m = 0; m < size<1>(tOrOpartial); ++m) {
                if (tObidh(m) >= 0 && scale(m) > 0.f) {
                    #pragma unroll
                    for (int k = 0; k < size<2>(tOrOpartial); ++k) {
                        if (Is_even_K || tOpO(k)) {
                            Tensor rOpartial = make_tensor_like<float>(tOrOpartial(_, m, k));
                            flash::convert_type_out(tOrOpartial(_, m, k), rOpartial);
                            #pragma unroll
                            for (int i = 0; i < size<0>(tOrOpartial); ++i) {
                                tOrO(i, m, k) += scale(m) * rOpartial[i];
                            }
                        }
                    }
                }
            }
        }

        // Step 7: Write the final O to gmem
        Tensor rO = make_tensor_like<Element>(tOrO);
        flash::convert_type_out(tOrO, rO);
        auto shape_O = make_shape(get<0>(params.shape_O_partial), get<1>(params.shape_O_partial) - k_block * kBlockK, get<3>(params.shape_O_partial), get<4>(params.shape_O_partial));
        Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset * get<0>(params.stride_O) + k_block * kBlockK * get<1>(params.stride_O)),
                                shape_O, params.stride_O)(_, _, _, !Varlen ? batch : 0);
        Tensor mO_copy = cute::tiled_divide(mO, Shape<_1, Int<kGmemElemsPerLoad>>{});
        GmemTiledCopy gmem_tiled_copy_O;
        auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);

        #pragma unroll
        for (int m = 0; m < size<1>(tOcO); ++m) {
            if (tObidh(m) >= 0)  {
                #pragma unroll
                for (int k = 0; k < size<2>(tOcO); ++k) {
                    int k_idx = get<1>(tOcO(_0{}, _0{}, k)) / kGmemElemsPerLoad;
                    if (Is_even_K || tOpO(k)) {
                        cute::copy(gmem_tiled_copy_O, rO(_, m, k), mO_copy(_, tOmidx(m), k_idx, tObidh(m)));
                    }
                }
            }
        }

    }

};

} // namespace flash
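
The combine kernel above reduces, for every output row, the per-split pairs (O_s, lse_s) into the exact full-attention output. As an editorial aid (not part of the deleted sources), here is a minimal host-side C++ sketch of that reduction; the function name `combine_splits` and the toy numbers are invented for illustration:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Reference combine for one output row: given per-split partial outputs
// O_s (each of width d) and per-split log-sum-exp values lse_s, produce
// the exact full-attention output and LSE. Mirrors kernel Steps 4-6.
void combine_splits(const std::vector<std::vector<float>>& o_partial,  // [num_splits][d]
                    const std::vector<float>& lse_partial,             // [num_splits]
                    std::vector<float>& o_out, float& lse_out) {
    int num_splits = (int)lse_partial.size();
    size_t d = o_partial[0].size();
    // Numerically stable log-sum-exp across splits.
    float lse_max = -INFINITY;
    for (float l : lse_partial) lse_max = std::max(lse_max, l);
    float lse_max_cur = (lse_max == -INFINITY) ? 0.f : lse_max;  // all splits empty
    float sum = 0.f;
    std::vector<float> scale(num_splits);
    for (int s = 0; s < num_splits; ++s) {
        scale[s] = std::exp(lse_partial[s] - lse_max_cur);
        sum += scale[s];
    }
    lse_out = std::log(sum) + lse_max;
    float inv_sum = (sum == 0.f || sum != sum) ? 0.f : 1.f / sum;
    // Weighted sum of partial outputs: O = sum_s exp(lse_s - lse) * O_s.
    o_out.assign(d, 0.f);
    for (int s = 0; s < num_splits; ++s) {
        float w = scale[s] * inv_sum;
        for (size_t j = 0; j < d; ++j) o_out[j] += w * o_partial[s][j];
    }
}

int main() {
    // Two splits of a toy row: the weights recover a convex combination.
    std::vector<std::vector<float>> o = {{1.f, 0.f}, {0.f, 1.f}};
    std::vector<float> lse = {std::log(3.f), std::log(1.f)};  // split 0 carries 3x the mass
    std::vector<float> out; float L;
    combine_splits(o, lse, out, L);
    std::printf("O = (%.3f, %.3f), LSE = %.3f\n", out[0], out[1], L);  // (0.750, 0.250), log(4)
    return 0;
}
```

The kernel performs this same reduction tile-by-tile, with warp reductions for the log-sum-exp and a multi-stage cp.async pipeline for the O_partial loads, instead of this serial loop.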
    	
flash-attn/flash_fwd_combine_launch_template.h
DELETED

@@ -1,88 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
- ******************************************************************************/
-
-#pragma once
-
-#include "cute/tensor.hpp"
-
-#include "cutlass/cutlass.h"
-#include "cutlass/arch/arch.h"  // For cutlass::arch::Sm80
-#include "cutlass/device_kernel.h"  // For device_kernel
-#include "cutlass/kernel_launch.h"  // For kernel_launch
-
-#include "static_switch.h"
-#include "flash.h"
-#include "flash_fwd_combine_kernel.h"
-
-using namespace cute;
-
-template <int Arch, int kBlockM, int kBlockK, int kLogMaxSplits, bool IsEvenK, bool Varlen, typename Element, typename ElementPartial>
-void run_flash_fwd_combine(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl) {
-    using ArchTag = std::conditional_t<Arch >= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>;
-    using TileShape_MK = cute::Shape<Int<kBlockM>, Int<kBlockK>>;
-    using CombineKernel = flash::FlashAttnFwdCombine<TileShape_MK, kLogMaxSplits, 256 /*kNThreads*/, 1 /*AlignmentLSE*/,
-                                                     IsEvenK, Varlen, Element, ElementPartial, ArchTag>;
-
-    typename CombineKernel::Arguments args {
-        params.b,
-        static_cast<ElementPartial const*>(params.oaccum_ptr),
-        {!Varlen ? params.seqlen_q : params.total_q, params.dv, params.num_splits, params.h, !Varlen ? params.b : 1},  // shape_O_partial
-        {params.oaccum_row_stride, _1{}, params.oaccum_split_stride, params.oaccum_head_stride, !Varlen ? params.oaccum_batch_stride : 0},  // stride_O_partial
-        static_cast<float*>(params.softmax_lseaccum_ptr),
-        {!Varlen ? params.seqlen_q : params.total_q, params.num_splits, params.h, !Varlen ? params.b : 1},  // shape_LSE_partial
-        {_1{}, params.lseaccum_split_stride, params.lseaccum_head_stride, !Varlen ? params.lseaccum_batch_stride : 0},  // stride_LSE_partial
-        static_cast<Element*>(params.o_ptr),
-        {params.o_row_stride, _1{}, params.o_head_stride, !Varlen ? params.o_batch_stride : 0},  // stride_O
-        static_cast<float*>(params.softmax_lse_ptr),
-        {_1{}, !Varlen ? params.seqlen_q : params.total_q, !Varlen ? params.h * params.seqlen_q : 0},  // stride_LSE
-        params.cu_seqlens_q, params.seqused_q, params.num_splits_dynamic_ptr, params.tile_count_semaphore
-    };
-
-    typename CombineKernel::SchedulerArguments scheduler_args {
-        params.b, params.seqlen_q, params.total_q, params.h, params.dv,
-        params.cu_seqlens_q, params.seqused_q
-    };
-
-    typename CombineKernel::Params kernel_params = {
-        CombineKernel::to_underlying_arguments(args),
-        CombineKernel::TileScheduler::to_underlying_arguments(scheduler_args)
-    };
-
-    dim3 grid_m = CombineKernel::TileScheduler::get_grid_shape(scheduler_args);
-    auto kernel = cutlass::device_kernel<CombineKernel>;
-    int smem_size = CombineKernel::SharedStorageSize;
-    if (smem_size >= 48 * 1024) {
-        CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-    }
-    // kernel<<<grid_m, CombineKernel::MaxThreadsPerBlock, smem_size, stream>>>(kernel_params);
-    cutlass::kernel_launch<CombineKernel>(grid_m, CombineKernel::MaxThreadsPerBlock, smem_size, stream, kernel_params, Arch >= 90 && enable_pdl /*launch_with_pdl*/);
-    CHECK_CUDA_KERNEL_LAUNCH();
-}
-
-template<typename T, typename Tpartial, int kBlockK>
-void run_mha_fwd_combine_(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl) {
-    // We want kBlockM to be as small as possible to maximize parallelism.
-    // E.g., if hdim is 64, we want kBlockM to be 16 so that we can use 256 threads, each reading 4 elements (floats).
-    static_assert(kBlockK % 32 == 0, "kBlockK must be a multiple of 32");
-    static constexpr int kBlockM = kBlockK % 128 == 0 ? 8 : (kBlockK % 64 == 0 ? 16 : 32);
-    ARCH_SWITCH(params.arch, Arch, [&] {
-        BOOL_SWITCH(params.cu_seqlens_q || params.seqused_q, Varlen, [&] {
-            if constexpr (kBlockM >= 16) {  // If kBlockM == 8 then the minimum number of splits is 32.
-                if (params.num_splits <= 16) {
-                    run_flash_fwd_combine<Arch, kBlockM, kBlockK, 4, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream, enable_pdl);
-                    return;
-                }
-            }
-            if (params.num_splits <= 32) {
-                run_flash_fwd_combine<Arch, kBlockM, kBlockK, 5, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream, enable_pdl);
-            } else if (params.num_splits <= 64) {
-                run_flash_fwd_combine<Arch, kBlockM, kBlockK, 6, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream, enable_pdl);
-            } else if (params.num_splits <= 128) {
-                run_flash_fwd_combine<Arch, kBlockM, kBlockK, 7, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream, enable_pdl);
-            } else {
-                run_flash_fwd_combine<Arch, kBlockM, kBlockK, 8, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream, enable_pdl);
-            }
-        });
-    });
-}
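
The dispatch in `run_mha_fwd_combine_` above reduces to two compile-time choices: `kBlockM`, kept as small as the head dimension allows to maximize parallelism, and `kLogMaxSplits`, the smallest power-of-two exponent that covers the runtime split count. A minimal host-side sketch of that selection logic (the `pick_*` helper names are invented for illustration and were not part of the deleted header):

```cpp
#include <cstdio>

// Hypothetical standalone re-derivation of the template dispatch in
// run_mha_fwd_combine_ above; pick_* are invented names for illustration.
constexpr int pick_kBlockM(int kBlockK) {
    // kBlockM shrinks as kBlockK grows so the tile stays small enough for a
    // 256-thread block; e.g. hdim 64 gives 16 x 64 = 1024 floats, 4 per thread.
    return kBlockK % 128 == 0 ? 8 : (kBlockK % 64 == 0 ? 16 : 32);
}

constexpr int pick_kLogMaxSplits(int num_splits, int kBlockM) {
    // Smallest exponent with 2^e >= num_splits; the kBlockM == 8 tiling
    // skips the 16-split instantiation, so its floor is 2^5 = 32.
    if (kBlockM >= 16 && num_splits <= 16) return 4;
    if (num_splits <= 32) return 5;
    if (num_splits <= 64) return 6;
    if (num_splits <= 128) return 7;
    return 8;
}

int main() {
    printf("hdim  64 -> kBlockM %d\n", pick_kBlockM(64));    // 16
    printf("hdim 128 -> kBlockM %d\n", pick_kBlockM(128));   // 8
    // 48 runtime splits fall through to the 2^6 = 64-split kernel.
    printf("48 splits -> kLogMaxSplits %d\n", pick_kLogMaxSplits(48, 16));
    return 0;
}
```

So, for example, 48 runtime splits dispatch to the 64-split instantiation, and the kernel never needs more template variants than the five exponents 4 through 8.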
flash-attn/flash_fwd_kernel_sm80.h
DELETED

@@ -1,215 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2024, Tri Dao.
- ******************************************************************************/
-
-#pragma once
-
-#include "cute/tensor.hpp"
-
-#include <cutlass/cutlass.h>
-#include <cutlass/array.h>
-#include <cutlass/numeric_types.h>
-#include <cutlass/kernel_hardware_info.h>
-
-#include "seqlen.h"
-#include "utils.h"
-#include "softmax.h"
-
-namespace flash {
-
-using namespace cute;
-
-template <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
-class FlashAttnFwdSm80 {
-
-public:
-
-    // Type Aliases
-    using CollectiveMainloop = CollectiveMainloop_;
-    using CollectiveEpilogue = CollectiveEpilogue_;
-    static constexpr bool Is_causal = CollectiveMainloop::Is_causal;
-    static constexpr bool Is_local = CollectiveMainloop::Is_local;
-    static_assert(CollectiveMainloop::Varlen == CollectiveEpilogue::Varlen);
-    static constexpr bool Has_softcap = CollectiveMainloop::Has_softcap;
-    static constexpr bool Varlen = CollectiveMainloop::Varlen;
-    static constexpr bool PagedKV = CollectiveMainloop::PagedKV;
-    static constexpr bool Split = CollectiveMainloop::Split;
-    static constexpr bool Is_FP8 = CollectiveMainloop::Is_FP8;
-    static constexpr bool Transpose_V = CollectiveMainloop::Transpose_V;
-    static constexpr bool AppendKV = CollectiveMainloop::AppendKV;
-    static constexpr bool PackGQA = CollectiveMainloop::PackGQA;
-    static constexpr int NumProducerThreads = CollectiveMainloop::NumProducerThreads;
-    using SeqlenInfo_t = typename CollectiveMainloop::SeqlenInfo_t;
-
-    // Mainloop derived types
-    using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK;
-    using TiledMma = typename CollectiveMainloop::TiledMma;
-    using ArchTag = typename CollectiveMainloop::ArchTag;
-    using MainloopArguments = typename CollectiveMainloop::Arguments;
-    using MainloopParams = typename CollectiveMainloop::Params;
-
-    // Epilogue derived types
-    using EpilogueArguments = typename CollectiveEpilogue::Arguments;
-    using EpilogueParams = typename CollectiveEpilogue::Params;
-
-    static_assert(ArchTag::kMinComputeCapability >= 80);
-
-    using TileScheduler = TileScheduler_;
-    using TileSchedulerArguments = typename flash::TileSchedulerArguments;
-    using TileSchedulerParams = typename TileScheduler::Params;
-
-    static constexpr uint32_t NumThreads = CUTE_STATIC_V(size(TiledMma{}));
-    static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{}));
-    static constexpr uint32_t MinBlocksPerMultiprocessor = NumThreads == 128 ? 2 : 1;
-
-    // Kernel level shared memory storage
-    // We overlap the shared memory for the mainloop and epilogue. However, we only want smem_o to overlap with smem_v + smem_k and not smem_q
-    // and nothing else, so we'll pad in case sizeof(smem_o) > sizeof(smem_v) + sizeof(smem_k).
-    static constexpr int mainloop_smem_padding_ = int(sizeof(typename CollectiveEpilogue::TensorStorage))
-        - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v)))
-        - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k)));
-    static constexpr int mainloop_smem_padding = mainloop_smem_padding_ < 0 ? 0 : mainloop_smem_padding_;
-    struct SharedStorage {
-        struct TensorStorage : cute::aligned_struct<128> {
-            union {
-                struct {
-                    cute::array<uint32_t, mainloop_smem_padding / sizeof(uint32_t)> padding_;
-                    typename CollectiveMainloop::TensorStorage mainloop;
-                };
-                // We want smem_o to line up with the start of smem_v
-                typename CollectiveEpilogue::TensorStorage epilogue;
-            };
-        } tensors;
-
-        alignas(16) typename TileScheduler::SharedStorage smem_scheduler;
-
-    };
-
-    static constexpr int SharedStorageSize = sizeof(SharedStorage);
-
-    // Device side arguments
-    struct Arguments {
-        MainloopArguments mainloop{};
-        EpilogueArguments epilogue{};
-        cutlass::KernelHardwareInfo hw_info{};
-        TileSchedulerArguments scheduler{};
-    };
-
-    // Kernel entry point API
-    struct Params {
-        MainloopParams mainloop{};
-        EpilogueParams epilogue{};
-        cutlass::KernelHardwareInfo hw_info{};
-        TileSchedulerParams scheduler{};
-    };
-
-    //
-    // Methods
-    //
-
-    // Convert to underlying arguments. In this case, a simple copy for the aliased type.
-    static
-    Params
-    to_underlying_arguments(Arguments const& args) {
-        CUTLASS_TRACE_HOST("to_underlying_arguments():");
-
-        // Get SM count if needed, otherwise use user supplied SM count
-        int sm_count = args.hw_info.sm_count;
-        if (sm_count <= 0) {
-            CUTLASS_TRACE_HOST("  WARNING: Arguments do not include a valid SM count.\n"
-                "  For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
-            sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
-        }
-
-        CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
-
-        cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
-        return {
-            CollectiveMainloop::to_underlying_arguments(args.mainloop),
-            CollectiveEpilogue::to_underlying_arguments(args.epilogue),
-            hw_info,
-            TileScheduler::to_underlying_arguments(args.scheduler)
-        };
-    }
-
-    // Computes the kernel launch grid shape based on runtime parameters
-    static dim3
-    get_grid_shape(Params const& params) {
-        return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count * MinBlocksPerMultiprocessor);
-    }
-
-    static dim3
-    get_block_shape() {
-        return dim3(MaxThreadsPerBlock, 1, 1);
-    }
-
-    CUTLASS_DEVICE
-    void
-    operator()(Params const& params, char* smem_buf) {
-
-        static constexpr int kBlockM = get<0>(TileShape_MNK{});
-
-        SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
-
-        CollectiveMainloop mainloop;
-        CollectiveEpilogue epilogue;
-
-        TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.smem_scheduler));
-        // Initialize matmul objects.
-        TiledMma tiled_mma;
-
-        scheduler.init_consumer();
-
-        int warp_idx = cutlass::canonical_warp_idx_sync();
-        CUTLASS_PRAGMA_NO_UNROLL
-        for (auto work_tile_info = warp_idx == 0 ? scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler) : scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
-             work_tile_info.is_valid(params.scheduler);
-             work_tile_info = warp_idx == 0 ? scheduler.template get_next_work</*IsProducerWarp=*/true>(params.scheduler, work_tile_info) : scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {
-            // Attention output (GEMM-II) accumulator.
-            Tensor tOrO = partition_fragment_C(tiled_mma, select<0, 2>(TileShape_MNK{}));
-            float softmax_scale_log2 = params.mainloop.softmax_scale_log2;
-            // If there's tanh softcap, the scaling will be done before tanh.
-            auto block_coord = work_tile_info.get_block_coord(params.scheduler);
-            int const bidb = get<2>(block_coord);
-            if constexpr (Is_FP8 && !Has_softcap) {
-                int const bidh = get<1>(block_coord);
-                int const bidh_kv = !PackGQA ? params.mainloop.qhead_per_khead_divmod.divide(bidh) : bidh;
-                float const q_descale = params.mainloop.ptr_q_descale == nullptr ? 1.0f : params.mainloop.ptr_q_descale[bidb * get<0>(params.mainloop.stride_q_descale) + bidh_kv * get<1>(params.mainloop.stride_q_descale)];
-                float const k_descale = params.mainloop.ptr_k_descale == nullptr ? 1.0f : params.mainloop.ptr_k_descale[bidb * get<0>(params.mainloop.stride_k_descale) + bidh_kv * get<1>(params.mainloop.stride_k_descale)];
-                softmax_scale_log2 *= q_descale * k_descale;
-            }
-            flash::Softmax<2 * (2 * kBlockM / NumThreads), /*Max_offset=*/!Is_FP8 ? 0 : 8> softmax(softmax_scale_log2);
-
-            SeqlenInfo_t seqlen_info{
-                bidb,
-                get<0>(params.mainloop.shape_Q),
-                !PagedKV ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable),
-                get<0>(params.mainloop.shape_K_new),
-                params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new,
-                params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k,
-                params.mainloop.seqlens_rotary
-            };
-            if constexpr (AppendKV) {
-                bool tile_new_valid = mainloop.store_kv_new(
-                    params.mainloop, threadIdx.x, shared_storage, seqlen_info, block_coord);
-                if (tile_new_valid) { __syncthreads(); }
-            }
-            bool tile_valid = mainloop.mma(
-                params.mainloop, tOrO, softmax, threadIdx.x, seqlen_info, block_coord,
-                shared_storage);
-            scheduler.prefetch_next_work(params.scheduler, work_tile_info);
-            if (tile_valid) {
-                // if (threadIdx.x == 128) { printf("Before epilogue, bid.x = %d, bid.y = %d, bid.z = %d, m_block = %d, bidb = %d, split_idx = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, m_block, bidb, split_idx); }
-                epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma,
-                               threadIdx.x, block_coord);
-            } else {
-                // Write 0 to gO and -inf to gLSE.
-                epilogue.store_zero(params.epilogue, threadIdx.x, block_coord);
-            }
-        }
-
-    }
-
-};
-
-} // namespace flash
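
Worth calling out in `FlashAttnFwdSm80` above is the `SharedStorage` overlap: epilogue and mainloop shared memory share one union, with front-padding computed so that `smem_o` can only ever alias `smem_k` and `smem_v` (dead by epilogue time) and never `smem_q`. A small compilable sketch of the same padding arithmetic, with invented byte sizes standing in for the real CuTe `TensorStorage` types:

```cpp
#include <cstdint>
#include <cstdio>

// Illustrative stand-ins for the real TensorStorage types; the sizes are
// invented for the example, only the padding arithmetic mirrors the
// deleted kernel above.
struct MainloopSmem { uint8_t smem_k[4096]; uint8_t smem_v[4096]; uint8_t smem_q[2048]; };
struct EpilogueSmem { uint8_t smem_o[12288]; };

// Same computation as mainloop_smem_padding above: if smem_o is bigger than
// the K/V tiles it is allowed to alias, front-pad the mainloop arm so the
// excess lands on padding, never on smem_q (which is still live).
constexpr int padding_raw = int(sizeof(EpilogueSmem))
    - int(sizeof(MainloopSmem::smem_k)) - int(sizeof(MainloopSmem::smem_v));
constexpr int padding = padding_raw < 0 ? 0 : padding_raw;

struct Padded {
    uint8_t pad[padding == 0 ? 1 : padding];  // guard against a zero-size array
    MainloopSmem mainloop;
};

struct SharedStorage {
    union {
        Padded padded;
        EpilogueSmem epilogue;  // smem_o overlaps pad + smem_k + smem_v exactly
    };
};

int main() {
    // padding = 12288 - 4096 - 4096 = 4096, so SharedStorage is 14336 bytes
    // instead of the 22528 the two stages would cost without overlapping.
    printf("padding = %d, sizeof(SharedStorage) = %zu\n", padding, sizeof(SharedStorage));
    return 0;
}
```

With these stand-in sizes the union occupies 14336 bytes rather than the 22528 the two stages would need side by side, which is exactly the saving that lets the real kernel fit larger tiles into the SM's shared-memory budget.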
        flash-attn/flash_fwd_kernel_sm90.h
    DELETED
    
    | @@ -1,468 +0,0 @@ | |
| 1 | 
            -
            /******************************************************************************
         | 
| 2 | 
            -
             * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
         | 
| 3 | 
            -
             ******************************************************************************/
         | 
| 4 | 
            -
             | 
| 5 | 
            -
            #pragma once
         | 
| 6 | 
            -
             | 
| 7 | 
            -
            #include "cute/tensor.hpp"
         | 
| 8 | 
            -
             | 
| 9 | 
            -
            #include <cutlass/cutlass.h>
         | 
| 10 | 
            -
            #include <cutlass/arch/reg_reconfig.h>
         | 
| 11 | 
            -
            #include <cutlass/array.h>
         | 
| 12 | 
            -
            #include <cutlass/numeric_types.h>
         | 
| 13 | 
            -
            #include <cutlass/numeric_conversion.h>
         | 
| 14 | 
            -
            #include <cutlass/kernel_hardware_info.h>
         | 
| 15 | 
            -
            #include "cutlass/pipeline/pipeline.hpp"
         | 
| 16 | 
            -
             | 
| 17 | 
            -
            #include "cutlass/arch/grid_dependency_control.h"
         | 
| 18 | 
            -
             | 
| 19 | 
            -
            #include "seqlen.h"
         | 
| 20 | 
            -
            #include "utils.h"
         | 
| 21 | 
            -
            #include "softmax.h"
         | 
| 22 | 
            -
             | 
| 23 | 
            -
            namespace flash {
         | 
| 24 | 
            -
             | 
| 25 | 
            -
            using namespace cute;
         | 
| 26 | 
            -
             | 
| 27 | 
            -
            template <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
         | 
| 28 | 
            -
            class FlashAttnFwdSm90 {
         | 
| 29 | 
            -
             | 
| 30 | 
            -
            public:
         | 
| 31 | 
            -
             | 
| 32 | 
            -
                // Type Aliases
         | 
| 33 | 
            -
                using CollectiveMainloop = CollectiveMainloop_;
         | 
| 34 | 
            -
                using CollectiveEpilogue = CollectiveEpilogue_;
         | 
| 35 | 
            -
                static constexpr bool Is_causal = CollectiveMainloop::Is_causal;
         | 
| 36 | 
            -
                static constexpr bool Is_local = CollectiveMainloop::Is_local;
         | 
| 37 | 
            -
                static_assert(CollectiveMainloop::Varlen == CollectiveEpilogue::Varlen);
         | 
| 38 | 
            -
                static constexpr bool Has_softcap = CollectiveMainloop::Has_softcap;
         | 
| 39 | 
            -
                static constexpr bool Varlen = CollectiveMainloop::Varlen;
         | 
| 40 | 
            -
                static constexpr bool Split = CollectiveMainloop::Split;
         | 
| 41 | 
            -
                static constexpr bool Is_FP8 = CollectiveMainloop::Is_FP8;
         | 
| 42 | 
            -
                static constexpr bool Transpose_V = CollectiveMainloop::Transpose_V;
         | 
| 43 | 
            -
                static constexpr bool AppendKV = CollectiveMainloop::AppendKV;
         | 
| 44 | 
            -
                static constexpr bool HasQv = CollectiveMainloop::HasQv;
         | 
| 45 | 
            -
                static constexpr bool Use_TMA_Q = CollectiveMainloop::Use_TMA_Q;
         | 
| 46 | 
            -
                static constexpr bool Use_TMA_KV = CollectiveMainloop::Use_TMA_KV;
         | 
| 47 | 
            -
                static constexpr bool Use_TMA_O = CollectiveEpilogue::Use_TMA_O;
         | 
| 48 | 
            -
                static constexpr bool PackGQA = CollectiveMainloop::PackGQA;
         | 
| 49 | 
            -
                static constexpr int NumProducerThreads = CollectiveMainloop::NumProducerThreads;
         | 
| 50 | 
            -
                static constexpr bool SameHeadDim = CollectiveMainloop::SameHeadDim;
         | 
| 51 | 
            -
                static constexpr bool LargeHeadDimV = CollectiveMainloop::LargeHeadDimV;
         | 
| 52 | 
            -
                static_assert(CollectiveMainloop::LargeHeadDimV == CollectiveEpilogue::LargeHeadDimV);
         | 
| 53 | 
            -
                using SeqlenInfo_t = typename CollectiveMainloop::SeqlenInfo_t;
         | 
| 54 | 
            -
             | 
| 55 | 
            -
                using SmemLayoutSAux = typename CollectiveMainloop::SmemLayoutSAux;
         | 
| 56 | 
            -
             | 
| 57 | 
            -
                // Mainloop derived types
         | 
| 58 | 
            -
                using TileShape_MNK_PV = typename CollectiveMainloop::TileShape_MNK_PV;
         | 
| 59 | 
            -
                using TiledMmaPV = typename CollectiveMainloop::TiledMmaPV;
         | 
| 60 | 
            -
                using ArchTag = typename CollectiveMainloop::ArchTag;
         | 
| 61 | 
            -
                using ClusterShape = typename CollectiveMainloop::ClusterShape;
         | 
| 62 | 
            -
                using MainloopArguments = typename CollectiveMainloop::Arguments;
         | 
| 63 | 
            -
                using MainloopParams = typename CollectiveMainloop::Params;
         | 
| 64 | 
            -
                using BarrierQ = std::conditional_t<Use_TMA_Q, cutlass::arch::ClusterTransactionBarrier, cutlass::arch::ClusterBarrier>;
         | 
| 65 | 
            -
             | 
| 66 | 
            -
                // Epilogue derived types
         | 
| 67 | 
            -
                using EpilogueArguments = typename CollectiveEpilogue::Arguments;
         | 
| 68 | 
            -
                using EpilogueParams = typename CollectiveEpilogue::Params;
         | 
| 69 | 
            -
             | 
| 70 | 
            -
                static_assert(ArchTag::kMinComputeCapability >= 90);
         | 
| 71 | 
            -
             | 
| 72 | 
            -
                using TileScheduler = TileScheduler_;
         | 
| 73 | 
            -
                using TileSchedulerArguments = typename flash::TileSchedulerArguments;
         | 
| 74 | 
            -
                using TileSchedulerParams = typename TileScheduler::Params;
         | 
| 75 | 
            -
             | 
| 76 | 
            -
                static constexpr uint32_t NumLoadWarpGroups = 1;
         | 
| 77 | 
            -
                static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMmaPV{})) / cutlass::NumThreadsPerWarpGroup;
         | 
| 78 | 
            -
                static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaPV{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup);
         | 
| 79 | 
            -
                static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
         | 
| 80 | 
            -
                static_assert(NumMmaWarpGroups == 1 || NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3);
         | 
| 81 | 
            -
             | 
| 82 | 
            -
                /// Register requirement for Load and Math WGs
         | 
| 83 | 
            -
                // If we use cp.async to load K and V, we need more registers for the producer WG.
         | 
| 84 | 
            -
                static constexpr uint32_t LoadRegisterRequirement = NumMmaWarpGroups == 1 ? 56 : (NumMmaWarpGroups == 2 ? (Use_TMA_KV ? 24 : 40) : 32);
         | 
| 85 | 
            -
                static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 1 ? 256 : (NumMmaWarpGroups == 2 ? (Use_TMA_KV ? 240 : 232) : 160);
         | 
| 86 | 
            -
                // If you want to print from the producer warp, you'd need to increase the number of registers
         | 
| 87 | 
            -
                // Otherwise you'll get CUDA error.
         | 
| 88 | 
            -
                // static constexpr uint32_t LoadRegisterRequirement = 40;
         | 
| 89 | 
            -
                // static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 232 : 152;
         | 
| 90 | 
            -
             | 
| 91 | 
            -
                // Kernel level shared memory storage
         | 
| 92 | 
            -
                // We overlap the shared memory for the mainloop and epilogue. However, we only want smem_o to overlap with smem_v
         | 
| 93 | 
            -
                // and nothing else, so we'll pad in case sizeof(smem_o) > sizeof(smem_v).
         | 
| 94 | 
            -
                static constexpr int mainloop_smem_padding_ = int(sizeof(typename CollectiveEpilogue::TensorStorage)) - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v)));
         | 
| 95 | 
            -
                static constexpr int mainloop_smem_padding = mainloop_smem_padding_ < 0 ? 0 : mainloop_smem_padding_;
         | 
| 96 | 
            -
                struct SharedStorage {
         | 
| 97 | 
            -
                    struct TensorStorage : cute::aligned_struct<128, _1> {
         | 
| 98 | 
            -
                        union {
         | 
| 99 | 
            -
                            struct {
         | 
| 100 | 
            -
                                cute::array<uint32_t, mainloop_smem_padding / sizeof(uint32_t)> padding_;
         | 
| 101 | 
            -
                                typename CollectiveMainloop::TensorStorage mainloop;
         | 
| 102 | 
            -
                            };
         | 
| 103 | 
            -
                            // We want smem_o to line up with the start of smem_v
         | 
| 104 | 
            -
                            typename CollectiveEpilogue::TensorStorage epilogue;
         | 
| 105 | 
            -
                        };
         | 
| 106 | 
            -
                    } tensors;
         | 
| 107 | 
            -
                    struct PipelineStorage : cute::aligned_struct<16, _1> {
         | 
| 108 | 
            -
                        alignas(16) BarrierQ barrier_Q;
         | 
| 109 | 
            -
                        alignas(16) BarrierQ barrier_Qv;
         | 
| 110 | 
            -
                        alignas(16) cutlass::arch::ClusterBarrier barrier_O;
         | 
| 111 | 
            -
                        alignas(16) typename CollectiveMainloop::MainloopPipelineK::SharedStorage pipeline_k;
         | 
| 112 | 
            -
                        alignas(16) typename CollectiveMainloop::MainloopPipelineV::SharedStorage pipeline_v;
         | 
| 113 | 
            -
                        alignas(16) typename CollectiveMainloop::MainloopPipelineVt::SharedStorage pipeline_vt;
         | 
| 114 | 
            -
                        alignas(16) typename CollectiveMainloop::MainloopPipelineKVNew::SharedStorage pipeline_k_new;
         | 
| 115 | 
            -
                        alignas(16) typename CollectiveMainloop::MainloopPipelineKVNew::SharedStorage pipeline_v_new;
         | 
| 116 | 
            -
                        alignas(16) typename TileScheduler::SharedStorage smem_scheduler;
         | 
| 117 | 
            -
                    } pipelines;
         | 
| 118 | 
            -
             | 
| 119 | 
            -
                };
         | 
| 120 | 
            -
             | 
| 121 | 
            -
                static constexpr int SharedStorageSize = sizeof(SharedStorage);
         | 
| 122 | 
            -
             | 
| 123 | 
            -
                // Device side arguments
         | 
| 124 | 
            -
                struct Arguments {
         | 
| 125 | 
            -
                    MainloopArguments mainloop{};
         | 
| 126 | 
            -
                    EpilogueArguments epilogue{};
         | 
| 127 | 
            -
                    cutlass::KernelHardwareInfo hw_info{};
         | 
| 128 | 
            -
                    TileSchedulerArguments scheduler{};
         | 
| 129 | 
            -
                };
         | 
| 130 | 
            -
             | 
| 131 | 
            -
                // Kernel entry point API
         | 
| 132 | 
            -
                struct Params {
         | 
| 133 | 
            -
                    MainloopParams mainloop{};
         | 
| 134 | 
            -
                    EpilogueParams epilogue{};
         | 
| 135 | 
            -
                    cutlass::KernelHardwareInfo hw_info{};
         | 
| 136 | 
            -
                    TileSchedulerParams scheduler{};
         | 
| 137 | 
            -
                };
         | 
| 138 | 
            -
             | 
| 139 | 
            -
                //
         | 
| 140 | 
            -
                // Methods
         | 
| 141 | 
            -
                //
         | 
| 142 | 
            -
             | 
| 143 | 
            -
                // Convert to underlying arguments. In this case, a simple copy for the aliased type.
         | 
| 144 | 
            -
                static
         | 
| 145 | 
            -
                Params
         | 
| 146 | 
            -
                to_underlying_arguments(Arguments const& args) {
         | 
| 147 | 
            -
                    CUTLASS_TRACE_HOST("to_underlying_arguments():");
         | 
| 148 | 
            -
             | 
| 149 | 
            -
                    // Get SM count if needed, otherwise use user supplied SM count
         | 
| 150 | 
            -
                    int sm_count = args.hw_info.sm_count;
         | 
| 151 | 
            -
                    if (sm_count <= 0) {
         | 
| 152 | 
            -
                        CUTLASS_TRACE_HOST("  WARNING: Arguments do not include a valid SM count.\n"
         | 
| 153 | 
            -
                            "  For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
         | 
| 154 | 
            -
                        sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
         | 
| 155 | 
            -
                    }
         | 
| 156 | 
            -
             | 
| 157 | 
            -
                    CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
         | 
| 158 | 
            -
             | 
| 159 | 
            -
                    cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
         | 
| 160 | 
            -
                    return {
         | 
| 161 | 
            -
                        CollectiveMainloop::to_underlying_arguments(args.mainloop),
         | 
| 162 | 
            -
                        CollectiveEpilogue::to_underlying_arguments(args.epilogue),
         | 
| 163 | 
            -
                        hw_info,
         | 
| 164 | 
            -
                        TileScheduler::to_underlying_arguments(args.scheduler)
         | 
| 165 | 
            -
                    };
         | 
| 166 | 
            -
                }
         | 
| 167 | 
            -
             | 
| 168 | 
            -
                // Computes the kernel launch grid shape based on runtime parameters
         | 
| 169 | 
            -
                static dim3
         | 
| 170 | 
            -
                get_grid_shape(Params const& params) {
         | 
| 171 | 
            -
                    return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count);
         | 
| 172 | 
            -
                }
         | 
| 173 | 
            -
             | 
| 174 | 
            -
                static dim3
         | 
| 175 | 
            -
                get_block_shape() {
         | 
| 176 | 
            -
                    return dim3(MaxThreadsPerBlock, 1, 1);
         | 
| 177 | 
            -
                }
         | 
| 178 | 
            -
             | 
| 179 | 
            -
                CUTLASS_DEVICE
         | 
| 180 | 
            -
                void
         | 
| 181 | 
            -
                operator()(Params const& params, char* smem_buf) {
         | 
| 182 | 
            -
             | 
| 183 | 
            -
                    static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup;
         | 
| 184 | 
            -
                    static constexpr int MmaThreadOffset = NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup;
         | 
| 185 | 
            -
                    static constexpr int kBlockM = get<0>(TileShape_MNK_PV{});
         | 
| 186 | 
            -
             | 
| 187 | 
            -
                    using MainloopPipelineK = typename CollectiveMainloop::MainloopPipelineK;
         | 
| 188 | 
            -
                    using MainloopPipelineV = typename CollectiveMainloop::MainloopPipelineV;
         | 
| 189 | 
            -
                    using MainloopPipelineVt = typename CollectiveMainloop::MainloopPipelineVt;
         | 
| 190 | 
            -
                    using MainloopPipelineKVNew = typename CollectiveMainloop::MainloopPipelineKVNew;
         | 
| 191 | 
            -
                    using PipelineState = typename CollectiveMainloop::PipelineState;
         | 
| 192 | 
            -
                    using PipelineParamsK = typename MainloopPipelineK::Params;
         | 
| 193 | 
            -
                    using PipelineParamsV = typename MainloopPipelineV::Params;
         | 
| 194 | 
            -
                    using PipelineParamsVt = typename MainloopPipelineVt::Params;
         | 
| 195 | 
            -
                    using PipelineParamsKVNew = typename MainloopPipelineKVNew::Params;
         | 
| 196 | 
            -
             | 
| 197 | 
            -
                    SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
         | 
| 198 | 
            -
             | 
| 199 | 
            -
                    int const lane_predicate = cute::elect_one_sync();
         | 
| 200 | 
            -
                    int const warp_idx = cutlass::canonical_warp_idx_sync();
         | 
| 201 | 
            -
             | 
| 202 | 
            -
                    // Issue Tma Descriptor Prefetch from a single thread
         | 
| 203 | 
            -
                    if (warp_idx == 0 && lane_predicate) {
         | 
| 204 | 
            -
                        CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
         | 
| 205 | 
            -
                        CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
         | 
| 206 | 
            -
                    }
         | 
| 207 | 
            -
             | 
| 208 | 
            -
                    // Obtain warp index
         | 
| 209 | 
            -
                    int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
         | 
| 210 | 
            -
                    int warp_group_idx = cutlass::canonical_warp_group_idx();
         | 
| 211 | 
            -
             | 
| 212 | 
            -
                    if (warp_idx == 0 && lane_predicate) {
         | 
| 213 | 
            -
                        shared_storage.pipelines.barrier_Q.init(Use_TMA_Q ? 1 : NumProducerThreads /*numThreads*/);
         | 
| 214 | 
            -
                        if constexpr (HasQv) {
         | 
| 215 | 
            -
                            shared_storage.pipelines.barrier_Qv.init(Use_TMA_Q ? 1 : NumProducerThreads /*numThreads*/);
         | 
| 216 | 
            -
                        }
         | 
| 217 | 
            -
                        shared_storage.pipelines.barrier_O.init(size(ClusterShape{}) * (Use_TMA_O ? 1 : NumMmaThreads) /*numThreads*/);
         | 
| 218 | 
            -
                    }
         | 
| 219 | 
            -
             | 
| 220 | 
            -
                    // We're counting on pipeline_k to call cutlass::arch::fence_barrier_init();
         | 
| 221 | 
            -
                    PipelineParamsK pipeline_params_k;
         | 
| 222 | 
            -
                    pipeline_params_k.role = warp_group_idx == 0
         | 
| 223 | 
            -
                        ? MainloopPipelineK::ThreadCategory::Producer
         | 
| 224 | 
            -
                        : MainloopPipelineK::ThreadCategory::Consumer;
         | 
| 225 | 
            -
                    if constexpr (Use_TMA_KV) {
         | 
| 226 | 
            -
                        pipeline_params_k.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
         | 
| 227 | 
            -
                        pipeline_params_k.is_leader = warp_group_thread_idx == 0;
         | 
| 228 | 
            -
                        pipeline_params_k.num_consumers = !LargeHeadDimV ? NumMmaThreads : cutlass::NumThreadsPerWarpGroup;
         | 
| 229 | 
            -
                    } else {
         | 
| 230 | 
            -
            pipeline_params_k.consumer_arv_count = !LargeHeadDimV ? NumMmaThreads : cutlass::NumThreadsPerWarpGroup;
            pipeline_params_k.producer_arv_count = NumProducerThreads;
        }

        static_assert(is_same_v<PipelineParamsK, PipelineParamsVt>);
        PipelineParamsVt pipeline_params_vt = pipeline_params_k;
        if constexpr (Use_TMA_KV && !SameHeadDim) {
            pipeline_params_vt.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV;
            if constexpr (LargeHeadDimV) { pipeline_params_vt.num_consumers = NumMmaThreads; }
        } else {
            if constexpr (LargeHeadDimV) { pipeline_params_vt.consumer_arv_count = NumMmaThreads; }
        }

        MainloopPipelineK pipeline_k = [&] {
            if constexpr (Use_TMA_KV) {
                return MainloopPipelineK(shared_storage.pipelines.pipeline_k, pipeline_params_k, ClusterShape{});
            } else {
                return MainloopPipelineK(shared_storage.pipelines.pipeline_k, pipeline_params_k);
            }
        }();
        // MainloopPipelineV pipeline_v(shared_storage.pipelines.pipeline_v, pipeline_params_v, ClusterShape{});
        MainloopPipelineV pipeline_v = [&] {
            if constexpr (!Transpose_V) {
                static_assert(is_same_v<PipelineParamsK, PipelineParamsV>);
                if constexpr (Use_TMA_KV) {
                    return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_vt, ClusterShape{});
                } else {
                    return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_vt);
                }
            } else {
                PipelineParamsV pipeline_params_v;
                pipeline_params_v.role = warp_group_idx == 0
                    ? MainloopPipelineV::ThreadCategory::Producer
                    : MainloopPipelineV::ThreadCategory::Consumer;
                pipeline_params_v.producer_arv_count = NumProducerThreads;
                pipeline_params_v.consumer_arv_count = NumMmaThreads;
                return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_v);
            }
        }();
        // If we need to transpose V (e.g. FP8 and V is row-major), we use pipeline_vt for the TMA, then
        // the producer WG will read from pipeline_vt and write to pipeline_v.
        // If we don't need to transpose V, we use pipeline_v for the TMA, and pipeline_vt won't be used.
        // Technically for pipeline_params_vt, warp0 of WG0 is the producer and all of WG0 are consumers.
        // However, the thread role isn't used in the pipeline implementation.
        MainloopPipelineVt pipeline_vt = [&] {
            if constexpr (Use_TMA_KV) {
                pipeline_params_vt.num_consumers = NumProducerThreads;  // TMA_V is only consumed by the producer WG
                return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_vt, ClusterShape{});
            } else {
                pipeline_params_vt.consumer_arv_count = NumProducerThreads;  // TMA_V is only consumed by the producer WG
                return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_vt);
            }
        }();

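For orientation, the pipelines constructed above implement CUTLASS's asynchronous producer/consumer protocol between the load warpgroup and the MMA warpgroups. A minimal sketch of the two sides, using the kernel's own type names and assuming the standard CUTLASS pipeline interface (with TMA, the transaction barrier plays the role of the producer's commit):

```cpp
// Producer warpgroup: reserve the next smem stage, fill it, advance.
PipelineState smem_pipe_write = cutlass::make_producer_start_state<MainloopPipelineK>();
pipeline_k.producer_acquire(smem_pipe_write);
// ... issue the TMA copy of the next K tile into stage smem_pipe_write.index() ...
++smem_pipe_write;  // the TMA transaction barrier commits the stage

// Consumer warpgroups: wait until the stage is filled, use it, release it.
PipelineState smem_pipe_read;
pipeline_k.consumer_wait(smem_pipe_read);
// ... run the Q @ K^T MMA against the tile in stage smem_pipe_read.index() ...
pipeline_k.consumer_release(smem_pipe_read);
++smem_pipe_read;
```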
            -
                    PipelineParamsKVNew pipeline_params_kv_new;
         | 
| 285 | 
            -
                    pipeline_params_kv_new.role = warp_group_idx == 0
         | 
| 286 | 
            -
                        ? MainloopPipelineKVNew::ThreadCategory::Producer
         | 
| 287 | 
            -
                        : MainloopPipelineKVNew::ThreadCategory::Consumer;
         | 
| 288 | 
            -
                    pipeline_params_kv_new.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
         | 
| 289 | 
            -
                    pipeline_params_kv_new.is_leader = warp_group_thread_idx == 0;
         | 
| 290 | 
            -
                    pipeline_params_kv_new.num_consumers = NumMmaThreads;
         | 
| 291 | 
            -
                    auto pipeline_k_new = cute::conditional_return<AppendKV>(MainloopPipelineKVNew(shared_storage.pipelines.pipeline_k_new, pipeline_params_kv_new, ClusterShape{}), nullptr);
         | 
| 292 | 
            -
                    if constexpr (!SameHeadDim) {
         | 
| 293 | 
            -
                        pipeline_params_kv_new.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV;
         | 
| 294 | 
            -
                    }
         | 
| 295 | 
            -
                    auto pipeline_v_new = cute::conditional_return<AppendKV>(MainloopPipelineKVNew(shared_storage.pipelines.pipeline_v_new, pipeline_params_kv_new, ClusterShape{}), nullptr);
         | 
| 296 | 
            -
             | 
| 297 | 
            -
                    CollectiveMainloop mainloop;
         | 
| 298 | 
            -
                    CollectiveEpilogue epilogue;
         | 
| 299 | 
            -
             | 
| 300 | 
            -
                    const int num_heads = get<2>(params.mainloop.shape_Q);
         | 
| 301 | 
            -
                    Tensor gS_aux = make_tensor(make_gmem_ptr(params.mainloop.ptr_S_aux), make_shape(num_heads));
         | 
| 302 | 
            -
                    Tensor sS_aux = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_s_aux.data()), SmemLayoutSAux{});
         | 
| 303 | 
            -
             | 
| 304 | 
            -
                    if(params.mainloop.ptr_S_aux && threadIdx.x < num_heads) {
         | 
| 305 | 
            -
                        sS_aux(threadIdx.x) = gS_aux(threadIdx.x);
         | 
| 306 | 
            -
                    }
         | 
| 307 | 
            -
             | 
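The `S_aux` staging above is the attention-sinks hookup mentioned in the README: one auxiliary logit per query head is copied from global to shared memory once per thread block, so the softmax can later fold it into its normalization. The same staging pattern in isolation (a standalone sketch; the kernel name and element type are illustrative, not from this file):

```cpp
#include <cuda_runtime.h>

// Stage one per-head value gmem -> smem using the first num_heads threads.
__global__ void stage_s_aux(const float* __restrict__ g_s_aux, int num_heads) {
    extern __shared__ float s_aux[];
    if (g_s_aux != nullptr && threadIdx.x < num_heads) {
        s_aux[threadIdx.x] = g_s_aux[threadIdx.x];
    }
    __syncthreads();  // every thread may now read any head's sink value
    // ... consume s_aux[head_idx] ...
}
```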
        // We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster
        if constexpr (size(ClusterShape{}) > 1) {
            cute::cluster_arrive_relaxed();
            cute::cluster_wait();
        } else {
            __syncthreads();
        }

        TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.pipelines.smem_scheduler));

        if (warp_group_idx == 0) {  // Producer
            cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();

            // The pipelines for AppendKV and main attention are different, since e.g. main attention
            // might use cp.async to load KV (if PagedKVNonTMA) while AppendKV always uses TMA to load
            // KV_new. Since the pipeline states are different, we have to manually sync to make
            // sure the two pipelines don't race when accessing smem_k and smem_v.
            PipelineState smem_pipe_write = cutlass::make_producer_start_state<MainloopPipelineK>();
            PipelineState smem_pipe_write_new = cutlass::make_producer_start_state<MainloopPipelineKVNew>();
            int work_idx = 0;
            int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
            static constexpr bool SingleProducerWarp = NumProducerThreads == cutlass::NumThreadsPerWarp;
            if constexpr (SingleProducerWarp) {
                if (warp_idx_in_warpgroup != 0) { return; }
            }
            if (!SingleProducerWarp && warp_idx_in_warpgroup != 0) { scheduler.init_consumer(); }

            cutlass::arch::wait_on_dependent_grids();

            // Load Q, K, V
            for (auto work_tile_info = SingleProducerWarp || warp_idx_in_warpgroup == 0 ? scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler) : scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
                 work_tile_info.is_valid(params.scheduler);
                 work_tile_info = SingleProducerWarp || warp_idx_in_warpgroup == 0 ? scheduler.template get_next_work</*IsProducerWarp=*/true>(params.scheduler, work_tile_info) : scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {

                auto block_coord = work_tile_info.get_block_coord(params.scheduler);
                SeqlenInfo_t seqlen_info{
                    get<2>(block_coord) /*bidb*/,
                    get<0>(params.mainloop.shape_Q),
                    !params.mainloop.ptr_pagetable ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable),
                    get<0>(params.mainloop.shape_K_new),
                    params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new,
                    params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k,
                    params.mainloop.seqlens_rotary
                };
                if constexpr (AppendKV) {
                    bool tile_new_valid = mainloop.load_kv_new(
                        params.mainloop, pipeline_k_new, pipeline_v_new,
                        smem_pipe_write_new, shared_storage, seqlen_info, block_coord, work_idx);
                    if (tile_new_valid) {
                        // if (threadIdx.x == 0) { printf("Producer: Before sync\n"); }
                        cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast<uint32_t>(FwdNamedBarriers::AppendKV) /*id*/);
                        // if (threadIdx.x == 0) { printf("Producer: After sync\n"); }
                    }
                }
                auto scheduler_prefetch = [&scheduler, &params, &work_tile_info]() {
                    scheduler.prefetch_next_work(params.scheduler, work_tile_info);
                };
                // pipeline_vt won't be used if we don't need to transpose V.
                mainloop.load(params.mainloop, pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write,
                              shared_storage, scheduler_prefetch, seqlen_info, block_coord, work_idx);
            }
            mainloop.load_tail(pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write, shared_storage, work_idx);
        } else {  // Consumer
            cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();

            // Initialize matmul objects.
            TiledMmaPV tiled_mma_pv;

            PipelineState smem_pipe_read;
            PipelineState smem_pipe_read_new;
            // We don't need separate variables smem_pipe_release_k and smem_pipe_release_v
            // (like in Cutlass's gemm) because the read and release pipeline states are always the same.

            scheduler.init_consumer();
            mainloop.mma_init();

            int work_idx = 0;
            CUTLASS_PRAGMA_NO_UNROLL
            for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
                 work_tile_info.is_valid(params.scheduler);
                 // get_next_work will be called before the epilogue
                 ) {
                auto block_coord = work_tile_info.get_block_coord(params.scheduler);
                int const bidb = get<2>(block_coord);
                SeqlenInfo_t seqlen_info{
                    bidb,
                    get<0>(params.mainloop.shape_Q),
                    !params.mainloop.ptr_pagetable ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable),
                    get<0>(params.mainloop.shape_K_new),
                    params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new,
                    params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k,
                    params.mainloop.seqlens_rotary
                };
                if constexpr (AppendKV) {
                    bool tile_new_valid = mainloop.store_kv_new(
                        params.mainloop, pipeline_k_new, pipeline_v_new, smem_pipe_read_new,
                        threadIdx.x - MmaThreadOffset, shared_storage, seqlen_info, block_coord);
                    if (tile_new_valid) {
                        // if (threadIdx.x == 128) { printf("Consumer: Before sync\n"); }
                        // We need this fence so that the gmem writes from the consumers are visible to the producer,
                        // which might do a TMA read after that.
                        asm volatile ("fence.proxy.async.global;");
                        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast<uint32_t>(FwdNamedBarriers::AppendKV) /*id*/);
                        // arrive is enough, we don't need sync. The producer will sync, which means
                        // after that sync we're guaranteed that the AppendKV pipeline has finished
                        // loading and consuming smem_k and smem_v.
                        // if (threadIdx.x == 128) { printf("Consumer: After sync\n"); }
                    }
                }
                // If there's tanh softcap, the scaling will be done before tanh.
                float softmax_scale_log2 = params.mainloop.softmax_scale_log2;
                if constexpr (Is_FP8 && !Has_softcap) {
                    int const bidh = get<1>(block_coord);
                    int const bidh_kv = !PackGQA ? params.mainloop.qhead_per_khead_divmod.divide(bidh) : bidh;
                    float const q_descale = params.mainloop.ptr_q_descale == nullptr ? 1.0f : params.mainloop.ptr_q_descale[bidb * get<0>(params.mainloop.stride_q_descale) + bidh_kv * get<1>(params.mainloop.stride_q_descale)];
                    float const k_descale = params.mainloop.ptr_k_descale == nullptr ? 1.0f : params.mainloop.ptr_k_descale[bidb * get<0>(params.mainloop.stride_k_descale) + bidh_kv * get<1>(params.mainloop.stride_k_descale)];
                    softmax_scale_log2 *= q_descale * k_descale;
                }
                flash::Softmax<!LargeHeadDimV ? 2 * (2 * kBlockM / NumMmaThreads) : 2, /*Max_offset=*/!Is_FP8 ? 0 : 8> softmax(softmax_scale_log2);
                // Attention output (GEMM-II) accumulator.
                Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_MNK_PV{}));
                bool tile_valid;
                if constexpr (!LargeHeadDimV) {
                    tile_valid = mainloop.mma(
                        params.mainloop, pipeline_k, pipeline_v, smem_pipe_read,
                        tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage);
                } else {  // mma_pv might not compile if !LargeHeadDimV
                    if (warp_group_idx == 1) {
                        tile_valid = mainloop.mma(
                            params.mainloop, pipeline_k, pipeline_v, smem_pipe_read,
                            tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage);
                    } else {
                        tile_valid = mainloop.mma_pv(
                            params.mainloop, pipeline_v, smem_pipe_read,
                            tOrO, softmax, threadIdx.x - MmaThreadOffset, seqlen_info, block_coord, shared_storage);
                    }
                }
                // Do this here before the epilogue so that the next tile is ready to go.
                work_tile_info = scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info);
                if constexpr (Split && Varlen) {
                    if (!work_tile_info.is_valid(params.scheduler)) {  // Last tile
                        cutlass::arch::launch_dependent_grids();
                    }
                }
                if (tile_valid) {
                    // if (threadIdx.x == 128) { printf("Before epilogue, bid.x = %d, bid.y = %d, bid.z = %d, m_block = %d, bidb = %d, split_idx = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, m_block, bidb, split_idx); }
                    epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma_pv,
                                   threadIdx.x - MmaThreadOffset, block_coord);
                } else {
                    // Write 0 to gO and -inf to gLSE.
                    epilogue.store_zero(params.epilogue, threadIdx.x - MmaThreadOffset, block_coord);
                }
            }
            epilogue.store_tail();
        }

    }

};

} // namespace flash
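One subtlety in the consumer path above deserves spelling out: for FP8 inputs, the per-head Q/K dequantization scales are folded into `softmax_scale_log2` rather than applied to every logit. Because the softmax is evaluated with base-2 exponentials, the fold is a single scalar multiply. A minimal sketch of the arithmetic, assuming (as in the flash-attention host code) that `softmax_scale_log2` is the softmax scale pre-multiplied by log2(e):

```cpp
#include <cmath>

// True logit: s = q_descale * k_descale * s_fp8, where s_fp8 = q_hat . k_hat is
// computed on the quantized tensors. The kernel needs exp(scale * s), and since
//   exp(scale * s) = exp2(scale * log2(e) * q_descale * k_descale * s_fp8),
// multiplying softmax_scale_log2 once by the descales dequantizes every logit
// at no extra per-element cost.
float softmax_weight(float s_fp8, float scale, float q_descale, float k_descale) {
    float softmax_scale_log2 = scale * 1.4426950408889634f;  // scale * log2(e)
    softmax_scale_log2 *= q_descale * k_descale;             // the fold done in the kernel
    return std::exp2f(softmax_scale_log2 * s_fp8);
}
```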
flash-attn/flash_fwd_launch_template.h
DELETED

@@ -1,231 +0,0 @@
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cute/tensor.hpp"

#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"  // For device_kernel
#include <cutlass/kernel_hardware_info.h>
#include "cutlass/cluster_launch.hpp"
#include "cutlass/kernel_launch.h"

#include "static_switch.h"
#include "flash.h"
#include "tile_size.h"
#include "tile_scheduler.hpp"
#include "flash_fwd_kernel_sm90.h"
#include "flash_fwd_kernel_sm80.h"
#include "mainloop_fwd_sm90_tma_gmma_ws.hpp"
#include "mainloop_fwd_sm80.hpp"
#include "epilogue_fwd.hpp"
#include "heuristics.h"

using namespace cute;

template <int Arch, int kHeadDim, int kHeadDimV, int ClusterM, typename Element, typename ElementOut,
          bool Is_causal, bool Is_local, bool Has_softcap, bool Varlen, bool PagedKVNonTMA, bool AppendKV, bool HasQv,
          bool PackGQA, bool Split, bool V_colmajor, bool Use_one_mma_wg>
void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    static_assert(!(Is_causal && Is_local), "Causal and Local cannot be enabled at the same time");
    static_assert(!(AppendKV && V_colmajor), "AppendKV and V_colmajor cannot be enabled at the same time");
    static_assert(!(AppendKV && !Varlen), "AppendKV requires Varlen");
    static constexpr bool Is_FP8 = cute::is_same_v<Element, cutlass::float_e4m3_t> || cute::is_same_v<Element, cutlass::float_e5m2_t>;
    static constexpr bool FP8_TransposeV = Is_FP8 && !V_colmajor;
    using ArchTag = std::conditional_t<Arch >= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>;
    using ElementS = cutlass::bfloat16_t;

    // Can't use structured binding since it's not compatible with constexpr
    static constexpr std::tuple<int, int, bool, bool> kBlockMN_RS_IntraWGOverlap = tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap, Use_one_mma_wg);
    static constexpr std::tuple<int, int, int, int, bool> kBlockMN_kNWarps_Stages_RS = tile_size_fwd_sm8x(Arch == 86 || Arch == 89, kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, PagedKVNonTMA, Varlen && Split, Has_softcap, AppendKV);
    static constexpr int kBlockM = Arch >= 90 ? std::get<0>(kBlockMN_RS_IntraWGOverlap) : std::get<0>(kBlockMN_kNWarps_Stages_RS);
    static constexpr int kBlockN = Arch >= 90 ? std::get<1>(kBlockMN_RS_IntraWGOverlap) : std::get<1>(kBlockMN_kNWarps_Stages_RS);
    static constexpr bool MmaPV_is_RS = std::get<2>(kBlockMN_RS_IntraWGOverlap);
    static constexpr bool IntraWGOverlap = std::get<3>(kBlockMN_RS_IntraWGOverlap);
    static constexpr int kNWarps = std::get<2>(kBlockMN_kNWarps_Stages_RS);
    static constexpr int kStages = Arch >= 90 ? 2 : std::get<3>(kBlockMN_kNWarps_Stages_RS);
    static constexpr bool Q_in_regs = Arch >= 90 ? false : std::get<4>(kBlockMN_kNWarps_Stages_RS);

    using TileShape_MNK = cute::Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
    using TileShape_MNK_PV = cute::Shape<Int<kBlockM>, Int<kHeadDimV>, Int<kBlockN>>;
    using ClusterShape = cute::Shape<Int<ClusterM>, _1, _1>;
    using CollectiveMainloop = std::conditional_t<
        Arch >= 90,
        flash::CollectiveMainloopFwdSm90<kStages, ClusterShape, TileShape_MNK, kHeadDimV, Element, float, cutlass::arch::Sm90, Is_causal, Is_local, Has_softcap, Varlen, PagedKVNonTMA, AppendKV, HasQv, MmaPV_is_RS, IntraWGOverlap, PackGQA, Split, V_colmajor, ElementS>,
        flash::CollectiveMainloopFwdSm80<kNWarps, kStages, Q_in_regs, TileShape_MNK, kHeadDimV, Element, float, cutlass::arch::Sm80, Is_causal, Is_local, Has_softcap, Varlen, PagedKVNonTMA, AppendKV, PackGQA, Split, ElementS>
    >;
    using CollectiveEpilogue = flash::CollectiveEpilogueFwd<TileShape_MNK_PV, ClusterShape, ElementOut, ArchTag, CollectiveMainloop::NumMmaThreads, Varlen, PackGQA, Split, FP8_TransposeV>;

    static constexpr int NumProducerThreads = Arch >= 90 ? CollectiveMainloop::NumProducerThreads : CollectiveMainloop::NumMmaThreads;
    using SchedulerPersistent = std::conditional_t<Varlen,
        flash::VarlenDynamicPersistentTileScheduler<kBlockM, CollectiveMainloop::NumMmaThreads, NumProducerThreads, Split, PackGQA, Arch >= 90 /*WarpSpecialized*/>,
        std::conditional_t<!Is_causal && !Is_local,
            flash::StaticPersistentTileScheduler<Split>,
            flash::DynamicPersistentTileScheduler<CollectiveMainloop::NumMmaThreads, NumProducerThreads, Split, PackGQA, Arch >= 90 /*WarpSpecialized*/>
        >
    >;
    using SchedulerSingleTile = flash::SingleTileScheduler<Varlen, Split, PackGQA, kBlockM>;
    // If Split then we probably don't have enough work for PersistentScheduler to be useful.
    // However, if Varlen (e.g., during decode where we have max_seqlens), using PersistentScheduler is better
    // since we'll avoid launching a bunch of thread blocks that immediately exit.
    // On Sm80, noncausal persistent seems a bit slower.
    static constexpr bool UsePersistentScheduler = Arch >= 90 ? !(Split && !Varlen) : ((Is_causal && !Varlen) || (Varlen && Split));
    using Scheduler = std::conditional_t<!UsePersistentScheduler, SchedulerSingleTile, SchedulerPersistent>;
    using AttnKernel = std::conditional_t<
        Arch >= 90,
        flash::enable_sm90_or_later<flash::FlashAttnFwdSm90<CollectiveMainloop, CollectiveEpilogue, Scheduler>>,
        flash::enable_sm80_to_sm89<flash::FlashAttnFwdSm80<CollectiveMainloop, CollectiveEpilogue, Scheduler>>
    >;

    bool const is_varlen_q = params.cu_seqlens_q;
    bool const is_varlen_k = params.cu_seqlens_k;
    bool const is_varlen_k_new = params.cu_seqlens_knew;
    int seqlen_q = !is_varlen_q ? params.seqlen_q : params.total_q;
    int batch_q = !is_varlen_q ? params.b : 1;
    int batch_k = !is_varlen_k ? (params.kv_batch_idx ? params.b_k : params.b) : 1;
    typename CollectiveMainloop::StrideV v_strides =
        cute::conditional_return<!V_colmajor>(
            make_stride(params.v_row_stride, _1{}, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0),
            make_stride(_1{}, params.v_dim_stride, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0));
    typename CollectiveMainloop::Arguments mainloop_args {
        static_cast<Element const*>(params.q_ptr),
        {seqlen_q, params.d, params.h, batch_q},  // shape_Q
        {params.q_row_stride, _1{}, params.q_head_stride, !is_varlen_q ? params.q_batch_stride : 0},  // stride_Q
        static_cast<Element*>(params.k_ptr),
        {!params.page_table ? (!is_varlen_k ? params.seqlen_k : params.total_k) : params.page_size,
         params.d, params.h_k, !params.page_table ? batch_k : params.num_pages},  // shape_K
        {params.k_row_stride, _1{}, params.k_head_stride, !is_varlen_k ? params.k_batch_stride : 0},  // stride_K
        static_cast<Element*>(params.v_ptr),
        params.dv,  // headdim_v
        v_strides,  // stride_V
        static_cast<Element const*>(params.knew_ptr),
        {!is_varlen_k_new ? params.seqlen_knew : params.total_knew, params.d, params.h_k, !is_varlen_k_new ? params.b : 1},  // shape_K_new
        {params.knew_row_stride, _1{}, params.knew_head_stride, !is_varlen_k_new ? params.knew_batch_stride : 0},  // stride_K_new
        static_cast<Element const*>(params.vnew_ptr),
        {params.vnew_row_stride, _1{}, params.vnew_head_stride, !is_varlen_k_new ? params.vnew_batch_stride : 0}, // stride_V_new
        static_cast<Element const*>(params.qv_ptr),
        {params.qv_row_stride, _1{}, params.qv_head_stride, !is_varlen_q ? params.qv_batch_stride : 0},  // stride_Qv
        static_cast<Element const*>(params.rotary_cos_ptr),
        {params.seqlen_k, params.rotary_dim / 2},  // shape_rotary, the seqlen shape doesn't matter
        {params.rotary_dim / 2, _1{}},  // stride_rotary_cos
        static_cast<Element const*>(params.rotary_sin_ptr),
        {params.rotary_dim / 2, _1{}},  // stride_rotary_sin
        params.is_rotary_interleaved,
        params.page_table,
        // if page_size is not set, avoid dividing by zero
        {params.kv_batch_idx ? params.b_k : params.b, !params.page_table ? 0 : params.seqlen_k / params.page_size}, // shape_page_table
        {params.page_table_batch_stride, _1{}},  // stride_page_table
        params.scale_softmax,
        params.q_descale_ptr, params.k_descale_ptr, params.v_descale_ptr,
        {params.q_descale_batch_stride, params.q_descale_head_stride},
        {params.k_descale_batch_stride, params.k_descale_head_stride},
        {params.v_descale_batch_stride, params.v_descale_head_stride},
        params.window_size_left, params.window_size_right,
        params.softcap,
        params.num_splits,
        params.kv_batch_idx,
        params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew,
        params.seqused_q, params.seqused_k,
        params.leftpad_k, params.seqlens_rotary,
        static_cast<ElementS const*>(params.s_aux_ptr)
    };
    typename CollectiveEpilogue::Arguments epilogue_args {
        static_cast<ElementOut*>(params.o_ptr),
        {seqlen_q, params.dv, params.h, batch_q, params.num_splits},  // shape_O
        {params.o_row_stride, _1{}, params.o_head_stride, !is_varlen_q ? params.o_batch_stride : 0, 0}, // stride_O
        static_cast<float*>(params.oaccum_ptr),
        {params.oaccum_row_stride, _1{}, params.oaccum_head_stride, !is_varlen_q ? params.oaccum_batch_stride : 0, params.oaccum_split_stride}, // stride_O_partial
        static_cast<float*>(params.softmax_lse_ptr),
        {_1{}, seqlen_q, !is_varlen_q ? params.h * seqlen_q : 0, 0},  // stride_LSE
        static_cast<float*>(params.softmax_lseaccum_ptr),
        {_1{}, seqlen_q, !is_varlen_q ? params.h * seqlen_q : 0, params.h * seqlen_q * batch_q},  // stride_LSE_partial
        params.h_k,
        params.cu_seqlens_q, params.seqused_q
    };

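A pattern worth noting in the argument construction above: with variable-length (packed) sequences, the batch dimension collapses to 1, all rows live in one packed allocation, and the batch stride must be zeroed. In isolation (a sketch; the struct and variable names are illustrative):

```cpp
#include <cstdint>

// Shape/stride selection for packed varlen batches (illustrative).
struct QView { int seqlen, batch; int64_t batch_stride; };

QView make_q_view(const int* cu_seqlens_q, int b, int seqlen_q_static,
                  int total_q, int64_t q_batch_stride) {
    bool is_varlen = (cu_seqlens_q != nullptr);
    return {
        !is_varlen ? seqlen_q_static : total_q,  // rows actually stored
        !is_varlen ? b : 1,                      // batch folded into the rows
        // A nonzero batch stride would index outside the packing; per-sequence
        // offsets come from cu_seqlens_q instead.
        !is_varlen ? q_batch_stride : 0
    };
}
```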
            -
                int qhead_per_khead = !PackGQA ? 1 : cutlass::ceil_div(params.h, params.h_k);
         | 
| 149 | 
            -
                int num_blocks_m = cutlass::ceil_div(params.seqlen_q * qhead_per_khead, get<0>(TileShape_MNK{}));
         | 
| 150 | 
            -
                num_blocks_m = cutlass::round_up(num_blocks_m, size<0>(ClusterShape{}));
         | 
| 151 | 
            -
                typename flash::TileSchedulerArguments scheduler_args {
         | 
| 152 | 
            -
                    num_blocks_m, !PackGQA ? params.h : params.h_k, params.b, params.num_splits,
         | 
| 153 | 
            -
                    params.h / params.h_k,
         | 
| 154 | 
            -
                    params.seqlen_q,
         | 
| 155 | 
            -
                    params.seqlen_k, params.d, params.dv, sizeof(Element),
         | 
| 156 | 
            -
                    params.tile_count_semaphore, params.cu_seqlens_q, params.seqused_q,
         | 
| 157 | 
            -
                    // params.num_m_blocks_ptr,
         | 
| 158 | 
            -
                    params.num_splits_dynamic_ptr,
         | 
| 159 | 
            -
                };
         | 
| 160 | 
            -
             | 
| 161 | 
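The PackGQA block count above packs every query head that shares a KV head into the M dimension. A worked example of the arithmetic (the numbers are hypothetical, not from this file):

```cpp
#include "cutlass/fast_math.h"  // cutlass::ceil_div, cutlass::round_up

// h = 32 query heads sharing h_k = 4 KV heads -> 8 query heads per KV head.
// During decode (seqlen_q = 1) each KV head then owns 8 packed query rows, so
// one kBlockM = 128 tile covers them all instead of 8 separate tiles.
int example_num_blocks_m() {
    int qhead_per_khead = cutlass::ceil_div(32, 4);                  // 8
    int num_blocks_m = cutlass::ceil_div(1 * qhead_per_khead, 128);  // 1
    return cutlass::round_up(num_blocks_m, 1 /*ClusterM*/);          // 1
}
```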
            -
                if (Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation) {
         | 
| 162 | 
            -
                    prepare_varlen_num_blocks(params, stream, PackGQA, kBlockM, kBlockN, Arch >= 90 /*enable_pdl*/);
         | 
| 163 | 
            -
                    CHECK_CUDA_KERNEL_LAUNCH();
         | 
| 164 | 
            -
                }
         | 
| 165 | 
            -
             | 
| 166 | 
            -
                int device;
         | 
| 167 | 
            -
                CHECK_CUDA(cudaGetDevice(&device));
         | 
| 168 | 
            -
                typename AttnKernel::Params kernel_params = AttnKernel::to_underlying_arguments({
         | 
| 169 | 
            -
                    mainloop_args, epilogue_args, {device, params.num_sm}, scheduler_args
         | 
| 170 | 
            -
                });
         | 
| 171 | 
            -
             | 
| 172 | 
            -
                dim3 grid_dims = AttnKernel::get_grid_shape(kernel_params);
         | 
| 173 | 
            -
                dim3 block_dims = AttnKernel::get_block_shape();
         | 
| 174 | 
            -
                int smem_size = AttnKernel::SharedStorageSize;
         | 
| 175 | 
            -
                // int smem_size_q = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_q));
         | 
| 176 | 
            -
                // int smem_size_k = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k));
         | 
| 177 | 
            -
                // int smem_size_v = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v));
         | 
| 178 | 
            -
                // printf("smem_size = %d, q = %d, k = %d, v = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v);
         | 
| 179 | 
            -
                // Get the ptr to kernel function.
         | 
| 180 | 
            -
                if constexpr (size(ClusterShape{}) > 1) {
         | 
| 181 | 
            -
                    void const* kernel = (void const*) cutlass::device_kernel<AttnKernel>;
         | 
| 182 | 
            -
                    if (smem_size >= 48 * 1024) {
         | 
| 183 | 
            -
                        CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
         | 
| 184 | 
            -
                    }
         | 
| 185 | 
            -
                    dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
         | 
| 186 | 
            -
                    cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
         | 
| 187 | 
            -
                    cutlass::launch_kernel_on_cluster(launch_params, kernel, kernel_params);
         | 
| 188 | 
            -
                } else {
         | 
| 189 | 
            -
                    auto kernel = cutlass::device_kernel<AttnKernel>;
         | 
| 190 | 
            -
                    if (smem_size >= 48 * 1024) {
         | 
| 191 | 
            -
                        CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
         | 
| 192 | 
            -
                    }
         | 
| 193 | 
            -
                    // kernel<<<grid_dims, block_dims, smem_size, stream>>>(kernel_params);
         | 
| 194 | 
            -
                    cutlass::kernel_launch<AttnKernel>(grid_dims, block_dims, smem_size, stream, kernel_params,
         | 
| 195 | 
            -
                                                       Arch >= 90 && Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation /*launch_with_pdl*/);
         | 
| 196 | 
            -
                }
         | 
| 197 | 
            -
                CHECK_CUDA_KERNEL_LAUNCH();
         | 
| 198 | 
            -
            }
         | 
| 199 | 
            -
             | 
| 200 | 
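Both launch paths above raise the dynamic shared-memory cap before launching. This is a general CUDA requirement: a kernel may request at most 48 KB of dynamic shared memory by default, and anything larger must be opted in per kernel or the launch fails. A standalone sketch (the kernel name is hypothetical):

```cpp
#include <cuda_runtime.h>

__global__ void attn_kernel_stub() { extern __shared__ char smem[]; (void)smem; }

void launch_with_large_smem(dim3 grid, dim3 block, int smem_size, cudaStream_t stream) {
    if (smem_size >= 48 * 1024) {
        // Without this attribute, launches requesting more than 48 KB of
        // dynamic shared memory fail at launch time.
        cudaFuncSetAttribute(attn_kernel_stub,
                             cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
    }
    attn_kernel_stub<<<grid, block, smem_size, stream>>>();
}
```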
            -
            template<int Arch, typename T, int kHeadDim, int kHeadDimV, bool Split, bool PagedKVNonTMA, bool Has_softcap, bool PackGQA>
         | 
| 201 | 
            -
            void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) {
         | 
| 202 | 
            -
                static_assert(sizeof(T) == 2 || sizeof(T) == 1, "Only 16bit and 8bit are supported");
         | 
| 203 | 
            -
                static constexpr bool Is_FP8 = cute::is_same_v<T, cutlass::float_e4m3_t> || cute::is_same_v<T, cutlass::float_e5m2_t>;
         | 
| 204 | 
            -
                using T_out = std::conditional_t<!Is_FP8, T, cutlass::bfloat16_t>;
         | 
| 205 | 
            -
                CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
         | 
| 206 | 
            -
                    VCOLMAJOR_SWITCH(params.v_dim_stride != 1, V_colmajor_, [&] {
         | 
| 207 | 
            -
                        static constexpr bool V_colmajor = V_colmajor_ && sizeof(T) == 1;
         | 
| 208 | 
            -
                        VARLEN_SWITCH(params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k, Varlen, [&] {
         | 
| 209 | 
            -
                            BOOL_SWITCH(use_one_mma_wg(params), Use_one_mma_wg_, [&] {
         | 
| 210 | 
            -
                                // Avoid over compiliation by making sure this only get set if it is actually used, i.e. we currently only support one mma wg for 128 head dim and hopper
         | 
| 211 | 
            -
                                static constexpr bool Use_one_mma_wg = Use_one_mma_wg_ && Arch >= 90 && kHeadDim == 128;
         | 
| 212 | 
            -
             | 
| 213 | 
            -
                                // Only needed here to decide if we should use cluster
         | 
| 214 | 
            -
                                static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap, Use_one_mma_wg)) : 128;
         | 
| 215 | 
            -
             | 
| 216 | 
            -
                                static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ? (kHeadDim >= 128) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKVNonTMA && !Varlen;
         | 
| 217 | 
            -
                                BOOL_SWITCH(params.qv_ptr, HasQV_, [&] {
         | 
| 218 | 
            -
                                    static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 && kHeadDimV >= 256;
         | 
| 219 | 
            -
                                    APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] {
         | 
| 220 | 
            -
                                        // Only use Cluster if number of tiles along seqlen_q is even and not varlen
         | 
| 221 | 
            -
                                        CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] {
         | 
| 222 | 
            -
                                            static constexpr int ClusterM = Enable_cluster && Use_cluster ? 2 : 1;
         | 
| 223 | 
            -
                                            run_flash_fwd<Arch, kHeadDim, kHeadDimV, ClusterM, T, T_out, Is_causal, Is_local, Has_softcap, Varlen, PagedKVNonTMA, AppendKV && Varlen, HasQv, PackGQA, Split, V_colmajor, Use_one_mma_wg>(params, stream);
         | 
| 224 | 
            -
                                        });
         | 
| 225 | 
            -
                                    });
         | 
| 226 | 
            -
                                });
         | 
| 227 | 
            -
                            });
         | 
| 228 | 
            -
                        });
         | 
| 229 | 
            -
                    });
         | 
| 230 | 
            -
                });
         | 
| 231 | 
            -
            }
         | 
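`run_mha_fwd_` above is one long chain of `*_SWITCH` macros from `static_switch.h`, each turning a runtime flag into a `constexpr` template parameter by instantiating both branches of a lambda. A minimal sketch of the pattern (simplified relative to the real header):

```cpp
// Runtime bool -> compile-time bool, at the cost of compiling both branches.
#define BOOL_SWITCH(COND, CONST_NAME, ...)            \
    [&] {                                             \
        if (COND) {                                   \
            constexpr static bool CONST_NAME = true;  \
            return __VA_ARGS__();                     \
        } else {                                      \
            constexpr static bool CONST_NAME = false; \
            return __VA_ARGS__();                     \
        }                                             \
    }()

// Usage: dispatch a runtime flag to a compile-time specialization.
template <bool IsCausal> void run_specialized();
void dispatch(bool is_causal) {
    BOOL_SWITCH(is_causal, IsCausal, [&] { run_specialized<IsCausal>(); });
}
```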
flash-attn/flash_prepare_scheduler.cu
DELETED

@@ -1,124 +0,0 @@
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#include "cutlass/fast_math.h"
#include "cutlass/barrier.h"
#include "cutlass/arch/barrier.h"

#include "cutlass/arch/grid_dependency_control.h"

#include "flash.h"

namespace flash {

__global__ void prepare_varlen_num_blocks_kernel(
        int seqlen_q_static, int seqlen_k_static, int seqlen_k_new_static,
        int const* const cu_seqlens_q, int const* const cu_seqlens_k, int const* const cu_seqlens_k_new,
        int const* const seqused_q, int const* const seqused_k, int const* const leftpad_k_ptr,
        int num_batch, int num_head, int qhead_per_khead, int num_sm, int num_splits_static,
        cutlass::FastDivmod blockm_divmod, cutlass::FastDivmod blockn_divmod,
        int* const tile_count_semaphore,
        // int* const num_m_blocks_ptr,
        int* const num_splits_dynamic_ptr,
        bool enable_pdl) {

    static constexpr int kNumBatchPerWarp = cutlass::NumThreadsPerWarp - 1;
    static constexpr int kSmemSize = 1;
    // Assume that there's only one block in the grid
    __shared__ int total_blocks_smem[kSmemSize];

    // There's only 1 block in the grid, so might as well start launching the main attn kernel
    if (enable_pdl) { cutlass::arch::launch_dependent_grids(); }

| 34 | 
            -
                if (threadIdx.x < kSmemSize) { total_blocks_smem[threadIdx.x] = 0; }
         | 
| 35 | 
            -
                __syncthreads();
         | 
| 36 | 
            -
             | 
| 37 | 
            -
                if (threadIdx.x == 0 && tile_count_semaphore) { *tile_count_semaphore = 0; }
         | 
| 38 | 
            -
             | 
| 39 | 
            -
                int lane = threadIdx.x % cutlass::NumThreadsPerWarp;
         | 
| 40 | 
            -
             | 
| 41 | 
            -
                auto get_num_m_blocks = [&](int bidb_start) {
         | 
| 42 | 
            -
                    int batch_idx = lane + bidb_start;
         | 
| 43 | 
            -
                    int seqlen;
         | 
| 44 | 
            -
                    if (seqused_q) {
         | 
| 45 | 
            -
                        seqlen = batch_idx < num_batch ? seqused_q[batch_idx] : 0;
         | 
| 46 | 
            -
                    } else if (cu_seqlens_q) {
         | 
| 47 | 
            -
                        int cur_cu_seqlen = batch_idx <= num_batch ? cu_seqlens_q[batch_idx] : 0;
         | 
| 48 | 
            -
                        int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1);
         | 
| 49 | 
            -
                        seqlen = next_cu_seqlen - cur_cu_seqlen;
         | 
| 50 | 
            -
                    } else {
         | 
| 51 | 
            -
                        seqlen = seqlen_q_static;
         | 
| 52 | 
            -
                    }
         | 
| 53 | 
            -
                    seqlen *= qhead_per_khead;
         | 
| 54 | 
            -
                    return batch_idx < num_batch && lane < kNumBatchPerWarp
         | 
| 55 | 
            -
                        ? blockm_divmod.div(seqlen + blockm_divmod.divisor - 1) : 0;
         | 
| 56 | 
            -
                };
         | 
| 57 | 
            -
             | 
| 58 | 
            -
                auto get_num_n_blocks = [&](int bidb_start) {
         | 
| 59 | 
            -
                    int batch_idx = lane + bidb_start;
         | 
| 60 | 
            -
                    int leftpad_k = batch_idx < num_batch && leftpad_k_ptr != nullptr ? leftpad_k_ptr[batch_idx] : 0;
         | 
| 61 | 
            -
                    int seqlen;
         | 
| 62 | 
            -
                    if (seqused_k) {
         | 
| 63 | 
            -
                        seqlen = batch_idx < num_batch ? seqused_k[batch_idx] : 0;
         | 
| 64 | 
            -
                    } else if (cu_seqlens_k) {
         | 
| 65 | 
            -
                        int cur_cu_seqlen = batch_idx <= num_batch ? cu_seqlens_k[batch_idx] : 0;
         | 
| 66 | 
            -
                        int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1);
         | 
| 67 | 
            -
                        seqlen = next_cu_seqlen - cur_cu_seqlen;
         | 
| 68 | 
            -
                    } else {
         | 
| 69 | 
            -
                        seqlen = seqlen_k_static;
         | 
| 70 | 
            -
                    }
         | 
| 71 | 
            -
                    int seqlen_new;
         | 
| 72 | 
            -
                    if (cu_seqlens_k_new) {
         | 
| 73 | 
            -
                        int cur_cu_seqlen_new = batch_idx <= num_batch ? cu_seqlens_k_new[batch_idx] : 0;
         | 
| 74 | 
            -
                        int next_cu_seqlen_new = __shfl_down_sync(0xffffffff, cur_cu_seqlen_new, 1);
         | 
| 75 | 
            -
                        seqlen_new = next_cu_seqlen_new - cur_cu_seqlen_new;
         | 
| 76 | 
            -
                    } else {
         | 
| 77 | 
            -
                        seqlen_new = seqlen_k_new_static;
         | 
| 78 | 
            -
                    }
         | 
| 79 | 
            -
                    // if (threadIdx.x == 0) { printf("seqlen = %d, seqlen_new = %d, leftpad_k = %d\n", seqlen, seqlen_new, leftpad_k); }
         | 
| 80 | 
            -
                    seqlen = seqlen - leftpad_k + seqlen_new;
         | 
| 81 | 
            -
                    return batch_idx < num_batch && lane < kNumBatchPerWarp
         | 
| 82 | 
            -
                        ? blockn_divmod.div(seqlen + blockn_divmod.divisor - 1) : 0;
         | 
| 83 | 
            -
                };
         | 
| 84 | 
            -
             | 
| 85 | 
            -
                int warp_idx = threadIdx.x / cutlass::NumThreadsPerWarp;
         | 
| 86 | 
            -
                int bidb_start = kNumBatchPerWarp * warp_idx;
         | 
| 87 | 
            -
                int num_m_blocks = get_num_m_blocks(bidb_start);
         | 
| 88 | 
            -
                int num_n_blocks = get_num_n_blocks(bidb_start);
         | 
| 89 | 
            -
             | 
| 90 | 
            -
                int total_blocks = num_m_blocks * num_n_blocks;
         | 
| 91 | 
            -
                // Warp sum
         | 
| 92 | 
            -
                #pragma unroll
         | 
| 93 | 
            -
                for (int i = cutlass::NumThreadsPerWarp / 2; i >= 1; i /= 2) {
         | 
| 94 | 
            -
                    total_blocks += __shfl_down_sync(0xffffffff, total_blocks, i);
         | 
| 95 | 
            -
                }
         | 
| 96 | 
            -
                if (lane == 0) { atomicAdd(total_blocks_smem, total_blocks); }
         | 
| 97 | 
            -
                __syncthreads();
         | 
| 98 | 
            -
                total_blocks = total_blocks_smem[0];
         | 
| 99 | 
            -
                // 10% margin
         | 
| 100 | 
            -
                int blocks_per_sm = static_cast<int>(ceilf(float(total_blocks) * 1.1f * float(num_head) / float(num_sm)));
         | 
| 101 | 
            -
                // blocks_per_sm = std::max(1, blocks_per_sm);  // 1 is the minimum number of blocks per SM
         | 
| 102 | 
            -
                int num_splits_dynamic = std::max(std::min((num_n_blocks + blocks_per_sm - 1) / blocks_per_sm, num_splits_static), 1);
         | 
| 103 | 
            -
                if (bidb_start + lane < num_batch && lane < kNumBatchPerWarp) {
         | 
| 104 | 
            -
                    num_splits_dynamic_ptr[bidb_start + lane] = num_splits_dynamic;
         | 
| 105 | 
            -
                    // printf("idx = %d, num_m_blocks = %d, num_n_blocks = %d, num_split_static = %d, num_splits_dynamic = %d\n", bidb_start + lane, num_m_blocks_ptr[bidb_start + lane], num_n_blocks, num_splits_static, num_splits_dynamic);
         | 
| 106 | 
            -
                }
         | 
| 107 | 
            -
            }
         | 
| 108 | 
            -
             | 
| 109 | 
            -
            } // flash
         | 
| 110 | 
            -
             | 
| 111 | 
            -
            void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bool packgqa,
         | 
| 112 | 
            -
                                           int blockM, int blockN, bool enable_pdl) {
         | 
| 113 | 
            -
                // Only support batch <= 992 (32 warps, each with 31 batches)
         | 
| 114 | 
            -
                int qhead_per_khead = !packgqa ? 1 : cutlass::ceil_div(params.h, params.h_k);
         | 
| 115 | 
            -
                flash::prepare_varlen_num_blocks_kernel<<<1 /*grid*/, 1024 /*block*/, 0, stream>>>(
         | 
| 116 | 
            -
                    params.seqlen_q, params.seqlen_k, params.seqlen_knew,
         | 
| 117 | 
            -
                    params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew,
         | 
| 118 | 
            -
                    params.seqused_q, params.seqused_k, params.leftpad_k,
         | 
| 119 | 
            -
                    params.b, !packgqa ? params.h : params.h_k, qhead_per_khead, params.num_sm, params.num_splits,
         | 
| 120 | 
            -
                    cutlass::FastDivmod(blockM), cutlass::FastDivmod(blockN),
         | 
| 121 | 
            -
                    params.tile_count_semaphore,
         | 
| 122 | 
            -
                    // params.num_m_blocks_ptr,
         | 
| 123 | 
            -
                    params.num_splits_dynamic_ptr, enable_pdl);
         | 
| 124 | 
            -
            }
         | 
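The kernel above assigns kNumBatchPerWarp = 31 batches to each of the 32 warps in its single 1024-thread block (one lane per warp is sacrificed because the cumulative-seqlen difference needs lane+1's value via __shfl_down_sync), which is exactly the "batch <= 992" limit (32 × 31) noted in the launch wrapper; the blockm_divmod.div(seqlen + blockm_divmod.divisor - 1) expression is a ceiling division by the tile size. The per-lane tile counts are then totaled with a shuffle tree reduction. Below is a minimal, standalone CUDA sketch of just that warp-sum step (a hypothetical demo with made-up names, not part of the removed source):

    #include <cstdio>
    #include <cuda_runtime.h>

    // One warp; each lane contributes one value, lane 0 ends up with the sum.
    __global__ void warp_sum_demo(const int* per_lane, int* out) {
        int lane = threadIdx.x % 32;
        int val = per_lane[lane];
        // Tree reduction over log2(32) = 5 shuffle steps, mirroring the
        // "Warp sum" loop in prepare_varlen_num_blocks_kernel.
        #pragma unroll
        for (int offset = 16; offset >= 1; offset /= 2) {
            val += __shfl_down_sync(0xffffffff, val, offset);
        }
        if (lane == 0) { *out = val; }
    }

    int main() {
        int h_in[32], h_out = 0;
        int *d_in, *d_out;
        for (int i = 0; i < 32; ++i) { h_in[i] = i; }  // sum of 0..31 is 496
        cudaMalloc(&d_in, sizeof(h_in));
        cudaMalloc(&d_out, sizeof(int));
        cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
        warp_sum_demo<<<1, 32>>>(d_in, d_out);
        cudaMemcpy(&h_out, d_out, sizeof(int), cudaMemcpyDeviceToHost);
        printf("warp sum = %d\n", h_out);  // prints: warp sum = 496
        cudaFree(d_in);
        cudaFree(d_out);
        return 0;
    }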
    	
        flash-attn/heuristics.h
    DELETED
    
    | @@ -1,65 +0,0 @@ | |
-/******************************************************************************
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
- ******************************************************************************/
-
-#pragma once
-
-#include <vector>
-#include "flash.h"
-
-inline bool use_one_mma_wg(Flash_fwd_params const& params) {
-    return params.arch >= 90 && params.d == 128 &&
-        params.seqlen_q * (!params.pack_gqa ? 1 : params.h / params.h_k) <= 64;
-};
-
-inline bool should_pack_gqa(bool varlen_q, int seqlen_q, int qhead_per_khead, int blockM) {
-    // If varlen, we don't actually know seqlen_q but only max_seqlen_q.
-    if (varlen_q) return true;
-    // Heuristic: PackGQA is a bit slower but can help if seqlen_q is small or not near a multiple of kBlockM
-    auto round_up = [](int a, int b) { return (a + b - 1) / b * b; };
-    float nopack_gqa_efficiency = float(seqlen_q) / float(round_up(seqlen_q, blockM));
-    float pack_gqa_efficiency = float(seqlen_q * qhead_per_khead) / float(round_up(seqlen_q * qhead_per_khead, blockM));
-    return nopack_gqa_efficiency < 0.9 * pack_gqa_efficiency;
-};
-
-// Find the number of splits that maximizes the occupancy. For example, if we have
-// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
-// better than having 3 splits (efficiency = 0.67). However, we also don't want too many
-// splits as that would incur more HBM reads/writes.
-// So we find the best efficiency, then find the smallest number of splits that gets 85%
-// of the best efficiency.
-inline int num_splits_heuristic(int total_mblocks, int num_SMs, int num_n_blocks, int num_m_blocks, int size_one_kv_head, bool is_causal_or_local, int max_splits) {
-    // If we have enough to almost fill the SMs, then just use 1 split
-    // However, in the case of super long seqlen where each head of KV doesn't even fit into
-    // L2 (we assume that L2 size is 50MB), we want to split.
-    if (total_mblocks >= 0.8f * num_SMs) {
-        int const size_l2 = 50 * 1024 * 1024;
-        // Only split if there are enough queries to go over the KV at least twice
-        // Don't split if causal
-        if (size_one_kv_head > size_l2 && num_m_blocks >= num_SMs * 2 && !is_causal_or_local) {
-            return std::min((size_one_kv_head + size_l2 - 1) / size_l2, max_splits);
-        } else {
-            return 1;
-        }
-    }
-    // If num_n_blocks is too small, use 1 split. For example, we never split for hdim = 128 and seqlen_k = 512.
-    if (num_n_blocks <= 4) { return 1; }
-    max_splits = std::min({max_splits, num_SMs, num_n_blocks});
-    float max_efficiency = 0.f;
-    std::vector<float> efficiency;
-    efficiency.reserve(max_splits);
-    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
-        float n_waves = float(total_mblocks * num_splits) / num_SMs;
-        float eff = n_waves / ceil(n_waves);
-        // printf("num_splits = %d, eff = %f\n", num_splits, eff);
-        if (eff > max_efficiency) { max_efficiency = eff; }
-        efficiency.push_back(eff);
-    }
-    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
-        if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {
-            // printf("num_splits chosen = %d\n", num_splits);
-            return num_splits;
-        }
-    }
-    return 1;
-}
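num_splits_heuristic scores each candidate split count by wave efficiency, n_waves / ceil(n_waves), then takes the smallest count within 85% of the best. The short host-side program below (a hypothetical demo, not part of the removed source) replays the arithmetic for the 48-blocks-on-108-SMs example from the comment above:

    #include <cmath>
    #include <cstdio>

    int main() {
        const int total_mblocks = 48;   // e.g. batch * n_heads worth of m-blocks
        const int num_SMs = 108;
        for (int num_splits = 1; num_splits <= 4; ++num_splits) {
            float n_waves = float(total_mblocks * num_splits) / num_SMs;
            float eff = n_waves / std::ceil(n_waves);
            printf("num_splits = %d -> efficiency = %.2f\n", num_splits, eff);
        }
        // Output: 0.44, 0.89, 0.67, 0.89 -- the best is 0.89, and 2 is the
        // smallest split count within 85% of it, so 2 splits would be chosen.
        return 0;
    }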
    	
        flash-attn/instantiations/flash_bwd_hdim128_bf16_sm80.cu
    DELETED
    
    | @@ -1,18 +0,0 @@ | |
-// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
-// Splitting the different template instantiations to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-
-#include "flash_bwd_launch_template.h"
-
-#ifndef FLASHATTENTION_DISABLE_SM8x
-#ifndef FLASHATTENTION_DISABLE_HDIM128
-template<>
-void run_mha_bwd_<80, cutlass::bfloat16_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim128<80, cutlass::bfloat16_t, false>(params, stream);
-}
-template<>
-void run_mha_bwd_<86, cutlass::bfloat16_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim128<86, cutlass::bfloat16_t, false>(params, stream);
-}
-#endif
-#endif
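Each of the instantiation files that follow repeats one pattern: the primary run_mha_bwd_ template is declared elsewhere (in the launch template header) and every auto-generated translation unit defines one explicit specialization per (arch, dtype, head dim, softcap) combination, so nvcc can compile the combinations independently and in parallel. A minimal host-only C++ sketch of that pattern (illustrative names, not the library's actual API):

    #include <cstdio>

    struct Params { int batch; };

    // Primary template: declared but never defined; only explicit
    // specializations (one per generated file) are ever linked.
    template <int Arch, typename T, int HeadDim, bool Softcap>
    void run_kernel_(Params& params);

    // The kind of specialization each generated .cu file contributes.
    template <>
    void run_kernel_<80, float, 128, false>(Params& params) {
        printf("arch=80 hdim=128 softcap=0 batch=%d\n", params.batch);
    }

    int main() {
        Params p{2};
        run_kernel_<80, float, 128, false>(p);  // resolved at compile time
        return 0;
    }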
    	
        flash-attn/instantiations/flash_bwd_hdim128_bf16_sm90.cu
    DELETED
    
    | @@ -1,12 +0,0 @@ | |
-// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
-// Splitting the different template instantiations to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-
-#include "flash_bwd_launch_template.h"
-
-#ifndef FLASHATTENTION_DISABLE_HDIM128
-template<>
-void run_mha_bwd_<90, cutlass::bfloat16_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim128<90, cutlass::bfloat16_t, false>(params, stream);
-}
-#endif
    	
        flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm80.cu
    DELETED
    
    | @@ -1,18 +0,0 @@ | |
-// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
-// Splitting the different template instantiations to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-
-#include "flash_bwd_launch_template.h"
-
-#ifndef FLASHATTENTION_DISABLE_SM8x
-#ifndef FLASHATTENTION_DISABLE_HDIM128
-template<>
-void run_mha_bwd_<80, cutlass::bfloat16_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim128<80, cutlass::bfloat16_t, true>(params, stream);
-}
-template<>
-void run_mha_bwd_<86, cutlass::bfloat16_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim128<86, cutlass::bfloat16_t, true>(params, stream);
-}
-#endif
-#endif
    	
        flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm90.cu
    DELETED
    
    | @@ -1,12 +0,0 @@ | |
-// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
-// Splitting the different template instantiations to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-
-#include "flash_bwd_launch_template.h"
-
-#ifndef FLASHATTENTION_DISABLE_HDIM128
-template<>
-void run_mha_bwd_<90, cutlass::bfloat16_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim128<90, cutlass::bfloat16_t, true>(params, stream);
-}
-#endif
    	
        flash-attn/instantiations/flash_bwd_hdim128_bf16_softcapall_sm90.cu
    DELETED
    
    | @@ -1,6 +0,0 @@ | |
-// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
-// Splitting the different template instantiations to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-
-#include "flash_bwd_hdim128_bf16_sm90.cu"
-#include "flash_bwd_hdim128_bf16_softcap_sm90.cu"
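Note that the "softcapall" translation units like this one contain no definitions of their own: they simply #include the corresponding plain and softcap .cu files, apparently so that both variants can be built in a single translation unit when the finer-grained split is not wanted.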
    	
        flash-attn/instantiations/flash_bwd_hdim128_fp16_sm80.cu
    DELETED
    
    | @@ -1,18 +0,0 @@ | |
-// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
-// Splitting the different template instantiations to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-
-#include "flash_bwd_launch_template.h"
-
-#ifndef FLASHATTENTION_DISABLE_SM8x
-#ifndef FLASHATTENTION_DISABLE_HDIM128
-template<>
-void run_mha_bwd_<80, cutlass::half_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim128<80, cutlass::half_t, false>(params, stream);
-}
-template<>
-void run_mha_bwd_<86, cutlass::half_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim128<86, cutlass::half_t, false>(params, stream);
-}
-#endif
-#endif
    	
        flash-attn/instantiations/flash_bwd_hdim128_fp16_sm90.cu
    DELETED
    
    | @@ -1,12 +0,0 @@ | |
-// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
-// Splitting the different template instantiations to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-
-#include "flash_bwd_launch_template.h"
-
-#ifndef FLASHATTENTION_DISABLE_HDIM128
-template<>
-void run_mha_bwd_<90, cutlass::half_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim128<90, cutlass::half_t, false>(params, stream);
-}
-#endif
    	
        flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm80.cu
    DELETED
    
    | @@ -1,18 +0,0 @@ | |
-// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
-// Splitting the different template instantiations to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-
-#include "flash_bwd_launch_template.h"
-
-#ifndef FLASHATTENTION_DISABLE_SM8x
-#ifndef FLASHATTENTION_DISABLE_HDIM128
-template<>
-void run_mha_bwd_<80, cutlass::half_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim128<80, cutlass::half_t, true>(params, stream);
-}
-template<>
-void run_mha_bwd_<86, cutlass::half_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim128<86, cutlass::half_t, true>(params, stream);
-}
-#endif
-#endif
    	
        flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm90.cu
    DELETED
    
    | @@ -1,12 +0,0 @@ | |
-// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
-// Splitting the different template instantiations to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-
-#include "flash_bwd_launch_template.h"
-
-#ifndef FLASHATTENTION_DISABLE_HDIM128
-template<>
-void run_mha_bwd_<90, cutlass::half_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim128<90, cutlass::half_t, true>(params, stream);
-}
-#endif
    	
        flash-attn/instantiations/flash_bwd_hdim128_fp16_softcapall_sm90.cu
    DELETED
    
    | @@ -1,6 +0,0 @@ | |
-// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
-// Splitting the different template instantiations to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-
-#include "flash_bwd_hdim128_fp16_sm90.cu"
-#include "flash_bwd_hdim128_fp16_softcap_sm90.cu"
    	
        flash-attn/instantiations/flash_bwd_hdim192_bf16_sm80.cu
    DELETED
    
    | @@ -1,18 +0,0 @@ | |
-// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
-// Splitting the different template instantiations to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-
-#include "flash_bwd_launch_template.h"
-
-#ifndef FLASHATTENTION_DISABLE_SM8x
-#ifndef FLASHATTENTION_DISABLE_HDIM192
-template<>
-void run_mha_bwd_<80, cutlass::bfloat16_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim192<80, cutlass::bfloat16_t, false>(params, stream);
-}
-template<>
-void run_mha_bwd_<86, cutlass::bfloat16_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim192<86, cutlass::bfloat16_t, false>(params, stream);
-}
-#endif
-#endif
    	
        flash-attn/instantiations/flash_bwd_hdim192_bf16_sm90.cu
    DELETED
    
    | @@ -1,12 +0,0 @@ | |
-// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
-// Splitting the different template instantiations to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-
-#include "flash_bwd_launch_template.h"
-
-#ifndef FLASHATTENTION_DISABLE_HDIM192
-template<>
-void run_mha_bwd_<90, cutlass::bfloat16_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim192<90, cutlass::bfloat16_t, false>(params, stream);
-}
-#endif
    	
        flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm80.cu
    DELETED
    
    | @@ -1,18 +0,0 @@ | |
-// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
-// Splitting the different template instantiations to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-
-#include "flash_bwd_launch_template.h"
-
-#ifndef FLASHATTENTION_DISABLE_SM8x
-#ifndef FLASHATTENTION_DISABLE_HDIM192
-template<>
-void run_mha_bwd_<80, cutlass::bfloat16_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim192<80, cutlass::bfloat16_t, true>(params, stream);
-}
-template<>
-void run_mha_bwd_<86, cutlass::bfloat16_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim192<86, cutlass::bfloat16_t, true>(params, stream);
-}
-#endif
-#endif
    	
        flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm90.cu
    DELETED
    
    | @@ -1,12 +0,0 @@ | |
-// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
-// Splitting the different template instantiations to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-
-#include "flash_bwd_launch_template.h"
-
-#ifndef FLASHATTENTION_DISABLE_HDIM192
-template<>
-void run_mha_bwd_<90, cutlass::bfloat16_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim192<90, cutlass::bfloat16_t, true>(params, stream);
-}
-#endif
    	
        flash-attn/instantiations/flash_bwd_hdim192_bf16_softcapall_sm90.cu
    DELETED
    
    | @@ -1,6 +0,0 @@ | |
-// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
-// Splitting the different template instantiations to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-
-#include "flash_bwd_hdim192_bf16_sm90.cu"
-#include "flash_bwd_hdim192_bf16_softcap_sm90.cu"
    	
        flash-attn/instantiations/flash_bwd_hdim192_fp16_sm80.cu
    DELETED
    
    | @@ -1,18 +0,0 @@ | |
-// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
-// Splitting the different template instantiations to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-
-#include "flash_bwd_launch_template.h"
-
-#ifndef FLASHATTENTION_DISABLE_SM8x
-#ifndef FLASHATTENTION_DISABLE_HDIM192
-template<>
-void run_mha_bwd_<80, cutlass::half_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim192<80, cutlass::half_t, false>(params, stream);
-}
-template<>
-void run_mha_bwd_<86, cutlass::half_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim192<86, cutlass::half_t, false>(params, stream);
-}
-#endif
-#endif
    	
        flash-attn/instantiations/flash_bwd_hdim192_fp16_sm90.cu
    DELETED
    
    | @@ -1,12 +0,0 @@ | |
-// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
-// Splitting the different template instantiations to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-
-#include "flash_bwd_launch_template.h"
-
-#ifndef FLASHATTENTION_DISABLE_HDIM192
-template<>
-void run_mha_bwd_<90, cutlass::half_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim192<90, cutlass::half_t, false>(params, stream);
-}
-#endif
    	
        flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm80.cu
    DELETED
    
    | @@ -1,18 +0,0 @@ | |
-// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
-// Splitting the different template instantiations to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-
-#include "flash_bwd_launch_template.h"
-
-#ifndef FLASHATTENTION_DISABLE_SM8x
-#ifndef FLASHATTENTION_DISABLE_HDIM192
-template<>
-void run_mha_bwd_<80, cutlass::half_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim192<80, cutlass::half_t, true>(params, stream);
-}
-template<>
-void run_mha_bwd_<86, cutlass::half_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim192<86, cutlass::half_t, true>(params, stream);
-}
-#endif
-#endif
    	
        flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm90.cu
    DELETED
    
    | @@ -1,12 +0,0 @@ | |
-// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
-// Splitting the different template instantiations to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-
-#include "flash_bwd_launch_template.h"
-
-#ifndef FLASHATTENTION_DISABLE_HDIM192
-template<>
-void run_mha_bwd_<90, cutlass::half_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim192<90, cutlass::half_t, true>(params, stream);
-}
-#endif
    	
        flash-attn/instantiations/flash_bwd_hdim192_fp16_softcapall_sm90.cu
    DELETED
    
    | @@ -1,6 +0,0 @@ | |
-// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
-// Splitting the different template instantiations to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-
-#include "flash_bwd_hdim192_fp16_sm90.cu"
-#include "flash_bwd_hdim192_fp16_softcap_sm90.cu"
    	
        flash-attn/instantiations/flash_bwd_hdim256_bf16_sm80.cu
    DELETED
    
    | @@ -1,18 +0,0 @@ | |
-// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
-// Splitting the different template instantiations to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-
-#include "flash_bwd_launch_template.h"
-
-#ifndef FLASHATTENTION_DISABLE_SM8x
-#ifndef FLASHATTENTION_DISABLE_HDIM256
-template<>
-void run_mha_bwd_<80, cutlass::bfloat16_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim256<80, cutlass::bfloat16_t, false>(params, stream);
-}
-template<>
-void run_mha_bwd_<86, cutlass::bfloat16_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim256<86, cutlass::bfloat16_t, false>(params, stream);
-}
-#endif
-#endif
    	
        flash-attn/instantiations/flash_bwd_hdim256_bf16_sm90.cu
    DELETED
    
    | @@ -1,12 +0,0 @@ | |
-// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
-// Splitting the different template instantiations to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-
-#include "flash_bwd_launch_template.h"
-
-#ifndef FLASHATTENTION_DISABLE_HDIM256
-template<>
-void run_mha_bwd_<90, cutlass::bfloat16_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim256<90, cutlass::bfloat16_t, false>(params, stream);
-}
-#endif
    	
        flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm80.cu
    DELETED
    
    | @@ -1,18 +0,0 @@ | |
-// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
-// Splitting the different template instantiations to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-
-#include "flash_bwd_launch_template.h"
-
-#ifndef FLASHATTENTION_DISABLE_SM8x
-#ifndef FLASHATTENTION_DISABLE_HDIM256
-template<>
-void run_mha_bwd_<80, cutlass::bfloat16_t, 256, true>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim256<80, cutlass::bfloat16_t, true>(params, stream);
-}
-template<>
-void run_mha_bwd_<86, cutlass::bfloat16_t, 256, true>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim256<86, cutlass::bfloat16_t, true>(params, stream);
-}
-#endif
-#endif
    	
        flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm90.cu
    DELETED
    
    | @@ -1,12 +0,0 @@ | |
-// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
-// Splitting the different template instantiations to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-
-#include "flash_bwd_launch_template.h"
-
-#ifndef FLASHATTENTION_DISABLE_HDIM256
-template<>
-void run_mha_bwd_<90, cutlass::bfloat16_t, 256, true>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim256<90, cutlass::bfloat16_t, true>(params, stream);
-}
-#endif
    	
        flash-attn/instantiations/flash_bwd_hdim256_bf16_softcapall_sm90.cu
    DELETED
    
    | @@ -1,6 +0,0 @@ | |
-// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
-// Splitting the different template instantiations to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-
-#include "flash_bwd_hdim256_bf16_sm90.cu"
-#include "flash_bwd_hdim256_bf16_softcap_sm90.cu"
    	
        flash-attn/instantiations/flash_bwd_hdim256_fp16_sm80.cu
    DELETED
    
    | @@ -1,18 +0,0 @@ | |
-// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
-// Splitting the different template instantiations to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-
-#include "flash_bwd_launch_template.h"
-
-#ifndef FLASHATTENTION_DISABLE_SM8x
-#ifndef FLASHATTENTION_DISABLE_HDIM256
-template<>
-void run_mha_bwd_<80, cutlass::half_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim256<80, cutlass::half_t, false>(params, stream);
-}
-template<>
-void run_mha_bwd_<86, cutlass::half_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim256<86, cutlass::half_t, false>(params, stream);
-}
-#endif
-#endif

