Lekr0 commited on
Commit
62dca4c
·
verified ·
1 Parent(s): 0b9402c

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. progress/SpecForge/.devcontainer/Dockerfile +32 -0
  2. progress/SpecForge/.devcontainer/devcontainer.json +30 -0
  3. progress/SpecForge/.github/CODEOWNERS +11 -0
  4. progress/SpecForge/.github/pull_request_template.md +30 -0
  5. progress/SpecForge/assets/logo.svg +0 -0
  6. progress/SpecForge/benchmarks/README.md +67 -0
  7. progress/SpecForge/benchmarks/__init__.py +3 -0
  8. progress/SpecForge/benchmarks/bench_eagle3.py +268 -0
  9. progress/SpecForge/benchmarks/benchmarker/__init__.py +29 -0
  10. progress/SpecForge/benchmarks/benchmarker/aime.py +133 -0
  11. progress/SpecForge/benchmarks/benchmarker/base.py +218 -0
  12. progress/SpecForge/benchmarks/benchmarker/ceval.py +267 -0
  13. progress/SpecForge/benchmarks/benchmarker/financeqa.py +59 -0
  14. progress/SpecForge/benchmarks/benchmarker/gpqa.py +85 -0
  15. progress/SpecForge/benchmarks/benchmarker/gsm8k.py +99 -0
  16. progress/SpecForge/benchmarks/benchmarker/humaneval.py +188 -0
  17. progress/SpecForge/benchmarks/benchmarker/livecodebench.py +46 -0
  18. progress/SpecForge/benchmarks/benchmarker/math500.py +122 -0
  19. progress/SpecForge/benchmarks/benchmarker/mmlu.py +82 -0
  20. progress/SpecForge/benchmarks/benchmarker/mmstar.py +185 -0
  21. progress/SpecForge/benchmarks/benchmarker/mtbench.py +59 -0
  22. progress/SpecForge/benchmarks/benchmarker/registry.py +31 -0
  23. progress/SpecForge/benchmarks/benchmarker/simpleqa.py +42 -0
  24. progress/SpecForge/benchmarks/benchmarker/utils.py +273 -0
  25. progress/SpecForge/cache/compiled_kernels/26/c26l7dxpqbfol7d62sqakxdv4rgyh27yhm4hrctevbkw5t6kekia.py +799 -0
  26. progress/SpecForge/cache/compiled_kernels/2d/c2d4e47kqxxnp6455gvkteqq3r336462zkbitosyeko6znxktn2b.py +879 -0
  27. progress/SpecForge/cache/compiled_kernels/2g/c2gswut4q57fp2ueybipg5qfqiy5coitofujwdnvqdwhr7nbvnyq.py +534 -0
  28. progress/SpecForge/cache/compiled_kernels/2j/4b74fa21eaaf86b6290185f6fe50aec9b905d858a087238ceddb52477f3f6acb.best_config +1 -0
  29. progress/SpecForge/cache/compiled_kernels/2j/c2j3mtk3thi6sn2hxiuhuigjw43spiu74mxdervpgpfrtos7u2qh.py +28 -0
  30. progress/SpecForge/cache/compiled_kernels/2n/c2ngvuchx6agpdr6v7awl3qgblaehfzaauoxn6camwvtk7syoxsk.py +715 -0
  31. progress/SpecForge/cache/compiled_kernels/2n/c2nooi7ekpz4qvmvghggbegd5cyfspb27jmq2snbi26zbrpoibnx.py +48 -0
  32. progress/SpecForge/cache/compiled_kernels/2n/d17ff4e7bb44e5ae89a267ef332bb7c074804ce0942fc0694c3ef15b05f7854a.best_config +1 -0
  33. progress/SpecForge/cache/compiled_kernels/2o/c2oashzxz74kzyuwo67tuhk32cike37ysabriftachdv7lf2qxgs.py +799 -0
  34. progress/SpecForge/cache/compiled_kernels/2v/c2vob47d7sxpitzmofyr55f5hvxsitxjhpyv5hdiqcdjgbwmxk76.py +799 -0
  35. progress/SpecForge/cache/compiled_kernels/2y/c2yhndikcsebqfmbw7l44gmcdoyw7ogaqt7quyeygz3mp5w6u6ke.py +715 -0
  36. progress/SpecForge/cache/compiled_kernels/2z/c2zdv5arszdl6ednyphqfnib6jwgzomr6zt6536b7gq75kp67uvh.py +1046 -0
  37. progress/SpecForge/cache/compiled_kernels/2z/c2zqq6qyjomc7iflknbqr7yjdhjux47hzv4nnsi5qfbeqglaip2h.py +707 -0
  38. progress/SpecForge/cache/compiled_kernels/32/8d96bbe05a966b7e7756831f09a79e31bf46fad0952af86f36d75557fc1735e8.best_config +1 -0
  39. progress/SpecForge/cache/compiled_kernels/32/c32pbcuz72bjfnkzvckfbbzlzuupc5yxl7t47b3qf74mmk5g2d2z.py +27 -0
  40. progress/SpecForge/cache/compiled_kernels/3b/a0a6b043ab548fdf71e72bbdf5daab7f72e9ed11a9ad9f8824a6263bb6bc5081.best_config +1 -0
  41. progress/SpecForge/cache/compiled_kernels/3b/c3bqw7dk7k6dcdrp3ycrthotye7y6zb26752jl4lwmfgaybpvr6y.py +27 -0
  42. progress/SpecForge/cache/compiled_kernels/3f/3f6057605b157d44fd56f748226a63975b79198f94871188e73e46cd6c7f8792.best_config +1 -0
  43. progress/SpecForge/cache/compiled_kernels/3f/c3fttv7enp2yvnla3r6jkk4galt2qdpxw577ghvkmmx6zqaqla74.py +54 -0
  44. progress/SpecForge/cache/compiled_kernels/3n/c3nlaqknekmjv2zuxzow4rf42v3gorxnfp6uod3dg3ic5ibp6yp3.py +715 -0
  45. progress/SpecForge/cache/compiled_kernels/3q/c3qbvcsx2w7qss2v3eocuadgz6t35joo33bflzqkxzzj747zcjpk.py +51 -0
  46. progress/SpecForge/cache/compiled_kernels/3q/fc5920467dd1501963c976e2b895fc37747fdebfa098fff912209055f3a31828.best_config +1 -0
  47. progress/SpecForge/cache/compiled_kernels/3r/c3rkwwyedldrjz6sidtx5huqcsdgpdpu4xndmm6h4e4boo6cbg2w.py +702 -0
  48. progress/SpecForge/cache/compiled_kernels/3z/c3zi2pt6zmbthc6ythgt5p4ednhp6m24gpscb2pt6adf6xojetua.py +799 -0
  49. progress/SpecForge/cache/compiled_kernels/3z/c3zilfzjywngbdehwphwkhzpt6qcv6jecvzdajl2d5hb73xe6yzw.py +582 -0
  50. progress/SpecForge/cache/compiled_kernels/4a/7887d45b1aa6124e232769adbe995f9cc2af0dd187cb9928540172d82c7b8631.best_config +1 -0
progress/SpecForge/.devcontainer/Dockerfile ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Dev-container image: SGLang dev base plus a non-root "devuser" account so
# files created inside the container are owned by the host user.
FROM lmsysorg/sglang:dev

# Create non-root user with specified UID and GID
# NOTE: Replace with your own UID and GID. This is a workaround from https://github.com/microsoft/vscode-remote-release/issues/49#issuecomment-489060908.
ARG HOST_UID=1003
ARG HOST_GID=1003
RUN groupadd -g $HOST_GID devuser && \
    useradd -m -u $HOST_UID -g $HOST_GID -s /bin/zsh devuser

# Give devuser passwordless sudo access (dev convenience; not for production images)
RUN apt-get update && apt-get install -y sudo && \
    echo "devuser ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/devuser && \
    rm -rf /var/lib/apt/lists/* && \
    apt-get clean

# Set up oh-my-zsh for devuser by copying root's shell/editor configs,
# then rewrite paths in .zshrc so they point at devuser's home.
RUN cp -r /root/.oh-my-zsh /home/devuser/.oh-my-zsh && \
    cp /root/.zshrc /home/devuser/.zshrc && \
    cp /root/.vimrc /home/devuser/.vimrc && \
    cp /root/.tmux.conf /home/devuser/.tmux.conf && \
    sed -i 's|/root/.oh-my-zsh|/home/devuser/.oh-my-zsh|g' /home/devuser/.zshrc && \
    chown -R devuser:devuser /home/devuser/

# Set workspace directory and ownership (devcontainer.json bind-mounts the repo here)
WORKDIR /sgl-workspace/sglang
RUN chown -R devuser:devuser /sgl-workspace

# Switch to devuser for all subsequent layers and the container runtime
USER devuser

# Install rust (needed for some SGLang tooling; runs as devuser so it lands in ~/.cargo)
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
progress/SpecForge/.devcontainer/devcontainer.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "sglang",
3
+ "build": {
4
+ "dockerfile": "Dockerfile"
5
+ },
6
+ "remoteUser": "devuser",
7
+ "customizations": {
8
+ "vscode": {
9
+ "extensions": [
10
+ // Python development
11
+ "ms-python.python",
12
+ "charliermarsh.ruff",
13
+ // Rust development
14
+ "rust-lang.rust-analyzer",
15
+ "tamasfe.even-better-toml"
16
+ ]
17
+ }
18
+ },
19
+ "forwardPorts": [],
20
+ "runArgs": [
21
+ "--gpus",
22
+ "all"
23
+ ],
24
+ // The two lines below ensures that your local changes in the sglang
25
+ // repo is automatically synced to the sglang pip package installed
26
+ // in the dev docker container. You can remove / comment out these
27
+ // two lines if you prefer to sync code changes manually.
28
+ "workspaceMount": "source=${localWorkspaceFolder},target=/sgl-workspace/specforge,type=bind",
29
+ "workspaceFolder": "/sgl-workspace/specforge"
30
+ }
progress/SpecForge/.github/CODEOWNERS ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .github @FrankLeeeee
2
+ /specforge/core @FrankLeeeee
3
+ /specforge/data @zyksir @sleepcoo @shuaills
4
+ /specforge/layers @FrankLeeeee @FlamingoPg @sleepcoo @shuaills
5
+ /specforge/modeling @FlamingoPg @sleepcoo @shuaills @FrankLeeeee
6
+ /tests @FrankLeeeee
7
+ /assets @FrankLeeeee @zhyncs
8
+ /examples @shuaills @sleepcoo @FlamingoPg
9
+ /configs @FrankLeeeee @FlamingoPg
10
+ /benchmarks @FrankLeeeee
11
+ /scripts @shuaills @sleepcoo @FlamingoPg
progress/SpecForge/.github/pull_request_template.md ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!-- Thank you for your contribution! We appreciate it. The following guidelines will help improve your pull request and facilitate feedback. If anything is unclear, don't hesitate to submit your pull request and ask the maintainers for assistance. -->
2
+
3
+ ## Motivation
4
+
5
+ <!-- Explain the purpose of this PR and the goals it aims to achieve. -->
6
+
7
+ ## Modifications
8
+
9
+ <!-- Describe the changes made in this PR. -->
10
+
11
+ ## Related Issues
12
+
13
+ <!-- Link to any related issues here. e.g. "Fixes #123" or "Closes #456" -->
14
+
15
+ ## Accuracy Test
16
+
17
+ <!-- If this PR affects model-side code (e.g., kernels, model architecture), please provide accuracy test results. Ref: https://docs.sglang.ai/references/accuracy_evaluation.html -->
18
+
19
+ ## Benchmark & Profiling
20
+
21
+ <!-- If this PR is expected to impact performance, please provide benchmark and profiling results. Ref: https://docs.sglang.ai/references/benchmark_and_profiling.html -->
22
+
23
+ ## Checklist
24
+
25
+ - [ ] Format your code according to the [Code Formatting with Pre-Commit](https://docs.sglang.ai/references/contribution_guide.html#code-formatting-with-pre-commit).
26
+ - [ ] Add unit tests as outlined in the [Running Unit Tests](https://docs.sglang.ai/references/contribution_guide.html#running-unit-tests-adding-to-ci).
27
+ - [ ] Update documentation / docstrings / example tutorials as needed, according to [Writing Documentation](https://docs.sglang.ai/references/contribution_guide.html#writing-documentation-running-docs-ci).
28
+ - [ ] Provide throughput / latency benchmark results and accuracy evaluation results as needed, according to [Benchmark and Profiling](https://docs.sglang.ai/references/benchmark_and_profiling.html) and [Accuracy Results](https://docs.sglang.ai/references/accuracy_evaluation.html).
29
+ - [ ] For reviewers: If you haven't made any contributions to this PR and are only assisting with merging the main branch, please remove yourself as a co-author when merging the PR.
30
+ - [ ] Please feel free to join our Slack channel at https://sgl-fru7574.slack.com/archives/C09784E3EN6 to discuss your PR.
progress/SpecForge/assets/logo.svg ADDED
progress/SpecForge/benchmarks/README.md ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Benchmarking for Speculative Decoding
2
+
3
+ ## Overview
4
+
5
+ We provide a unified script to test the performance of speculative decoding with the EAGLE3 algorithm on multiple datasets. Follow the steps below to run the benchmarks.
6
+
7
+ ## Run Benchmarks
8
+
9
+ ### Launch SGLang and Benchmarker Concurrently
10
+
11
+ `bench_eagle3.py` can launch an SGLang server process and a benchmarking process concurrently. In this way, you don't have to launch the SGLang server manually; the script automatically handles the SGLang launch under different speculative decoding configurations. Some important arguments are:
12
+ - `--model-path`: the path to the target model.
13
+ - `--speculative-draft-model-path`: the path to the draft model.
14
+ - `--port`: the port to launch the SGLang server.
15
+ - `--trust-remote-code`: trust the remote code.
16
+ - `--mem-fraction-static`: the memory fraction for the static memory.
17
+ - `--tp-size`: the tensor parallelism size.
18
+ - `--attention-backend`: the attention backend.
19
+ - `--config-list`: the list of speculative decoding configuration to test, the format is `<batch-size>,<num-steps>,<topk>,<num-draft-tokens>`.
20
+ - `--benchmark-list`: the list of benchmarks to test, the format is `<benchmark-name>:<num-prompts>:<subset>`.
21
+
22
+ ```shell
23
+ python3 bench_eagle3.py \
24
+ --model-path meta-llama/Llama-3.1-8B-Instruct \
25
+ --speculative-draft-model-path lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B \
26
+ --port 30000 \
27
+ --trust-remote-code \
28
+ --mem-fraction-static 0.8 \
29
+ --tp-size 1 \
30
+ --attention-backend fa3 \
31
+ --config-list 1,0,0,0 1,3,1,4 \
32
+ --benchmark-list mtbench gsm8k:5 ceval:5:accountant \
33
+ --dtype bfloat16
34
+ ```
35
+
36
+ ### Launch Benchmarker Independently
37
+
38
+ If you want to launch the SGLang server independently, you can use the following command.
39
+
40
+ ```shell
41
+ # you can launch a server
42
+ python3 -m sglang.launch_server \
43
+ --model meta-llama/Llama-3.1-8B-Instruct \
44
+ --speculative-algorithm EAGLE3 \
45
+ --speculative-draft-model-path lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B \
46
+ --speculative-num-steps 3 \
47
+ --speculative-eagle-topk 1 \
48
+ --speculative-num-draft-tokens 4 \
49
+ --mem-fraction-static 0.75 \
50
+ --cuda-graph-max-bs 1 \
51
+ --tp 1 \
52
+ --trust-remote-code \
53
+ --host 0.0.0.0 \
54
+ --port 30000 \
55
+ --dtype bfloat16
56
+ ```
57
+
58
+ Then we can start benchmarking. Use the same host and port as the running SGLang server, and pass `--skip-launch-server` so the script skips launching its own SGLang server.
59
+
60
+ ```bash
61
+ python bench_eagle3.py \
62
+ --model-path meta-llama/Llama-3.1-8B-Instruct \
63
+ --port 30000 \
64
+ --config-list 1,3,1,4 \
65
+ --benchmark-list mtbench:5 ceval:5:accountant gsm8k:5 humaneval:5 math500:5 mtbench:5 aime:1 \
66
+ --skip-launch-server
67
+ ```
progress/SpecForge/benchmarks/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """
2
+ Benchmark scripts for speculative decoding evaluation.
3
+ """
progress/SpecForge/benchmarks/bench_eagle3.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Usage:
4
+
5
+ # if you want to run benchmarks directly
6
+ # mtbench:20 means only run 20 samples in the dataset
7
+ python bench_eagle3.py \
8
+ --model meta-llama/Llama-3.1-8B-Instruct \
9
+ --speculative-algorithm EAGLE3 \
10
+ --speculative-draft-model-path lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B \
11
+ --port 30000 \
12
+ --config-list 1,0,0,0 1,3,1,4 \
13
+ --benchmark-list mtbench:20 \
14
+ --dtype bfloat16
15
+
16
+
17
+ or, if you want to run SGLang on its own:
18
+
19
+ # launch sglang
20
+ python3 -m sglang.launch_server \
21
+ --model meta-llama/Llama-3.1-8B-Instruct \
22
+ --speculative-algorithm EAGLE3 \
23
+ --speculative-draft-model-path lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B \
24
+ --speculative-num-steps 3 \
25
+ --speculative-eagle-topk 1 \
26
+ --speculative-num-draft-tokens 4 \
27
+ --mem-fraction-static 0.75 \
28
+ --cuda-graph-max-bs 1 \
29
+ --tp 1 \
30
+ --trust-remote-code \
31
+ --host 0.0.0.0 \
32
+ --port 30000 \
33
+ --dtype bfloat16
34
+
35
+ # then run benchmarks
36
+ python bench_eagle3.py \
37
+ --model-path meta-llama/Llama-3.1-8B-Instruct \
38
+ --port 30000 \
39
+ --config-list 1,0,0,0 \
40
+ --benchmark-list mtbench:80 \
41
+ --dtype bfloat16 \
42
+ --skip-launch-server
43
+ """
44
+ import argparse
45
+ import json
46
+ import os
47
+ import time
48
+ from dataclasses import asdict
49
+ from typing import List
50
+
51
+ import requests
52
+ from benchmarker import BENCHMARKS
53
+ from sglang.srt.server_args import ServerArgs
54
+ from sglang.test.test_utils import kill_process_tree, popen_launch_server
55
+ from sglang.utils import wait_for_server
56
+
57
+
58
def parse_args():
    """Parse CLI arguments for the EAGLE3 benchmark driver.

    Two argument groups are combined on one parser: the full set of SGLang
    ``ServerArgs`` options (added via ``ServerArgs.add_cli_args``) and the
    benchmark-specific options defined below.

    Returns:
        argparse.Namespace with both server and benchmark options.
    """
    parser = argparse.ArgumentParser()
    sglang_group = parser.add_argument_group("sglang")
    ServerArgs.add_cli_args(sglang_group)

    # Keep benchmark-specific flags in their own group, separate from server flags.
    benchmark_group = parser.add_argument_group("benchmark")
    benchmark_group.add_argument(
        "--skip-launch-server", action="store_true", default=False
    )
    benchmark_group.add_argument("--timeout-for-server-launch", type=int, default=600)
    benchmark_group.add_argument("--num-prompts", type=int, default=80)
    benchmark_group.add_argument("--output-dir", type=str, default="./results")
    # Each config is "<batch-size>,<num-steps>,<topk>,<num-draft-tokens>";
    # a config with num-steps 0 runs without speculative decoding.
    benchmark_group.add_argument(
        "--config-list", type=str, nargs="+", default=["1,0,0,0", "1,3,1,4"]
    )
    benchmark_group.add_argument(
        "--name",
        type=str,
        default=None,
        help="name of this benchmark run, if provided, will be added to the output file name",
    )
    benchmark_group.add_argument(
        "--benchmark-list",
        type=str,
        nargs="+",
        default=[
            "mtbench:80",
            "gsm8k:200",
            "humaneval:200",
            "math500:200",
            "ceval:200",
        ],
        help=f"The list of benchmarks to run. The format is <benchmark-name>:<num-prompts>:<subset>,<subset>. We support the following benchmarks: {', '.join(BENCHMARKS.benchmarks.keys())}",
    )
    benchmark_group.add_argument(
        "--enable-multi-turn-conversation",
        action="store_true",
        default=False,
    )
    return parser.parse_args()
99
+
100
+
101
def launch_sglang_server(
    server_args: ServerArgs,
    base_url: str,
    batch_size: int,
    steps: int,
    topk: int,
    num_draft_tokens: int,
    timeout: int,
):
    """
    This function launches the SGLang server with the given server arguments.

    When ``steps > 0`` the server is started with EAGLE3 speculative decoding
    enabled; otherwise it runs as a plain (non-speculative) server.

    Args:
        server_args: Parsed SGLang ``ServerArgs`` (model path, tp size, etc.).
        base_url: URL of the server, used by the launcher for readiness checks.
        batch_size: Value used for both ``--cuda-graph-max-bs`` and
            ``--max-running-requests``.
        steps: ``--speculative-num-steps``; 0 disables speculative decoding.
        topk: ``--speculative-eagle-topk``.
        num_draft_tokens: ``--speculative-num-draft-tokens``.
        timeout: Seconds to wait for the server to come up.

    Returns:
        The process handle returned by ``popen_launch_server``.
    """
    sglang_args: List[str] = []
    # Only add the speculative-decoding flags when drafting is enabled.
    if steps > 0:
        sglang_args.extend(
            [
                "--speculative-algorithm",
                "EAGLE3",
                "--speculative-num-steps",
                str(steps),
                "--speculative-eagle-topk",
                str(topk),
                "--speculative-num-draft-tokens",
                str(num_draft_tokens),
                "--speculative-draft-model-path",
                server_args.speculative_draft_model_path,
            ]
        )

    sglang_args.extend(
        [
            "--cuda-graph-max-bs",
            str(batch_size),
            "--mem-fraction-static",
            str(server_args.mem_fraction_static),
            "--tp-size",
            str(server_args.tp_size),
            "--max-running-requests",
            str(batch_size),
        ]
    )

    # Forward a subset of optional ServerArgs fields only when they were set.
    if server_args.trust_remote_code:
        sglang_args.extend(["--trust-remote-code"])

    if server_args.disable_radix_cache:
        sglang_args.extend(["--disable-radix-cache"])

    if server_args.ep_size:
        sglang_args.extend(["--ep-size", str(server_args.ep_size)])

    if server_args.attention_backend:
        sglang_args.extend(["--attention-backend", server_args.attention_backend])

    if server_args.quantization:
        sglang_args.extend(["--quantization", server_args.quantization])

    if server_args.dtype:
        sglang_args.extend(["--dtype", server_args.dtype])

    # NOTE(review): the extra env vars appear to enable per-step timing and a
    # relaxed context-length check in the server — confirm against SGLang docs.
    process = popen_launch_server(
        server_args.model_path,
        base_url,
        timeout=timeout,
        other_args=sglang_args,
        env={
            "SGLANG_RECORD_STEP_TIME": "1",
            "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN": "1",
            **os.environ,
        },
    )
    return process
173
+
174
+
175
def send_flush_cache_request(base_url: str):
    """POST to the server's /flush_cache endpoint to clear its cache between benchmarks."""
    flush_url = base_url + "/flush_cache"
    requests.post(flush_url)
177
+
178
+
179
def main():
    """Entry point: parse args, launch SGLang per speculative config (unless
    skipped), run every requested benchmark, and write the results to a JSON
    file in ``--output-dir``."""
    args = parse_args()
    server_args: ServerArgs = ServerArgs.from_cli_args(args)
    # Each config is "<batch-size>,<num-steps>,<topk>,<num-draft-tokens>".
    configs = [tuple(map(int, config.split(","))) for config in args.config_list]

    # Split each benchmark spec into (bench_name, num_prompts, subset).
    # Accepted formats: "name", "name:<num-prompts>", "name:<num-prompts>:<subset>,<subset>".
    benchmark_list = []
    for item in args.benchmark_list:
        splits = item.split(":")
        if len(splits) == 1:
            bench_name = splits[0]
            num_prompts = None
            subset = None
        elif len(splits) == 2:
            bench_name, num_prompts = splits
            subset = None
        elif len(splits) == 3:
            bench_name, num_prompts, subset = splits
            subset = subset.split(",")
        else:
            raise ValueError(f"Invalid benchmark list format: {item}")
        benchmark_list.append((bench_name, num_prompts, subset))
    # Explicit check instead of `assert` so validation survives `python -O`.
    if not benchmark_list:
        raise ValueError("the number of benchmark list is 0")

    base_url = f"http://localhost:{args.port}"

    results = {}
    results["model"] = server_args.speculative_draft_model_path

    def run_benchmarks(batch_size: int, steps: int, topk: int, num_draft_tokens: int):
        """Run every configured benchmark once against the currently running server."""
        for benchmark_name, num_prompts, subset in benchmark_list:
            print(
                f"Running benchmark {benchmark_name} with {num_prompts} prompts, batch size {batch_size}, steps {steps}, topk {topk}, num_draft_tokens {num_draft_tokens}, subset {subset}"
            )
            # (renamed from the original misspelled `benchmarkder_cls`)
            benchmarker_cls = BENCHMARKS.get(benchmark_name)
            num_prompts = int(num_prompts) if num_prompts is not None else None
            if subset is None:
                benchmarker = benchmarker_cls(num_samples=num_prompts)
            else:
                benchmarker = benchmarker_cls(num_samples=num_prompts, subset=subset)
            metrics_list = benchmarker.run(
                host=args.host, port=args.port, batch_size=batch_size
            )
            # Flush the server cache so later benchmarks don't reuse earlier prefixes.
            send_flush_cache_request(base_url)
            if benchmark_name not in results:
                results[benchmark_name] = []
            results[benchmark_name].append(
                dict(
                    batch_size=batch_size,
                    steps=steps,
                    topk=topk,
                    num_draft_tokens=num_draft_tokens,
                    metrics=[asdict(metric) for metric in metrics_list],
                    num_samples=num_prompts,
                )
            )

    if args.skip_launch_server:
        # Server is already running externally; only the batch size matters here.
        batch_size = configs[0][0] if len(configs) > 0 else 8
        run_benchmarks(batch_size, None, None, None)
    else:
        # Iterate over each config, launching a fresh server per configuration.
        for batch_size, steps, topk, num_draft_tokens in configs:
            process = launch_sglang_server(
                server_args,
                base_url,
                batch_size,
                steps,
                topk,
                num_draft_tokens,
                args.timeout_for_server_launch,
            )
            wait_for_server(base_url)
            run_benchmarks(batch_size, steps, topk, num_draft_tokens)
            kill_process_tree(process.pid)
            process.wait()

    os.makedirs(args.output_dir, exist_ok=True)
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    # The run is written as a single indented JSON document, so use a .json
    # extension (the previous .jsonl suffix wrongly implied JSON Lines).
    result_file = os.path.join(
        args.output_dir,
        f"{args.name + '_' if args.name else ''}results_{timestamp}.json",
    )
    with open(result_file, "w") as f:
        json.dump(results, f, indent=4)
    print(f"Results saved to {result_file}")
265
+
266
+
267
+ if __name__ == "__main__":
268
+ main()
progress/SpecForge/benchmarks/benchmarker/__init__.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .aime import AIMEBenchmarker
2
+ from .ceval import CEvalBenchmarker
3
+ from .financeqa import FinanceQABenchmarker
4
+ from .gpqa import GPQABenchmarker
5
+ from .gsm8k import GSM8KBenchmarker
6
+ from .humaneval import HumanEvalBenchmarker
7
+ from .livecodebench import LCBBenchmarker
8
+ from .math500 import Math500Benchmarker
9
+ from .mmlu import MMLUBenchmarker
10
+ from .mmstar import MMStarBenchmarker
11
+ from .mtbench import MTBenchBenchmarker
12
+ from .registry import BENCHMARKS
13
+ from .simpleqa import SimpleQABenchmarker
14
+
15
+ __all__ = [
16
+ "BENCHMARKS",
17
+ "AIMEBenchmarker",
18
+ "CEvalBenchmarker",
19
+ "GSM8KBenchmarker",
20
+ "HumanEvalBenchmarker",
21
+ "Math500Benchmarker",
22
+ "MTBenchBenchmarker",
23
+ "MMStarBenchmarker",
24
+ "GPQABenchmarker",
25
+ "FinanceQABenchmarker",
26
+ "MMLUBenchmarker",
27
+ "LCBBenchmarker",
28
+ "SimpleQABenchmarker",
29
+ ]
progress/SpecForge/benchmarks/benchmarker/aime.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AIME benchmark
3
+ """
4
+
5
+ import re
6
+ from typing import Any, Dict, List, Optional, Tuple
7
+
8
+ from datasets import load_dataset
9
+
10
+ from .base import Benchmarker
11
+ from .registry import BENCHMARKS
12
+ from .utils import create_simple_sgl_function
13
+
14
+
15
def extract_aime_answer(output: str) -> Optional[str]:
    r"""Extract the final answer from an AIME problem solution.

    AIME answers are integers between 0 and 999 and are usually given in
    \boxed{} format. Extraction is attempted in decreasing order of
    reliability: \boxed{...}, bare \boxed, "answer is ..." phrases, and
    finally the last integer in the 0-999 range.

    (Docstring is raw so "\b" in "\boxed" is not interpreted as a backspace
    escape.)

    Args:
        output: Raw model output string.

    Returns:
        The extracted answer as a digit string, or None if nothing
        answer-like is found.
    """
    # Try to find answer in \boxed{} format
    boxed_pattern = r"\\boxed\{([^}]+)\}"
    match = re.search(boxed_pattern, output)
    if match:
        answer = match.group(1).strip()
        # Extract number from the boxed content
        numbers = re.findall(r"\d+", answer)
        if numbers:
            return numbers[-1]  # Take the last number (usually the final answer)
        return answer

    # Try to find answer in \boxed format (without braces)
    boxed_pattern2 = r"\\boxed\s+(\d+)"
    match = re.search(boxed_pattern2, output)
    if match:
        return match.group(1).strip()

    # Look for patterns like "The answer is 42" or "Answer: 123".
    # NOTE: "\s*" now sits outside the alternation so "is 42" / "equals 7"
    # match; the original "(?:is|equals?|=\s*)" only allowed spaces after "=".
    answer_patterns = [
        r"(?:answer|Answer|ANSWER)[\s:]+(\d+)",
        r"(?:final\s+answer|Final\s+Answer)[\s:]+(\d+)",
        r"(?:is|equals?|=)\s*(\d+)\s*$",
    ]
    for pattern in answer_patterns:
        matches = re.findall(pattern, output, re.IGNORECASE)
        if matches:
            return matches[-1].strip()

    # Fallback: extract the last integer in the text
    numbers = re.findall(r"\b(\d+)\b", output)
    if numbers:
        # Filter to reasonable AIME answer range (0-999)
        valid_numbers = [n for n in numbers if 0 <= int(n) <= 999]
        if valid_numbers:
            return valid_numbers[-1]

    return None
58
+
59
+
60
@BENCHMARKS.register("aime")
class AIMEBenchmarker(Benchmarker):
    """AIME benchmark implementation.

    Loads the Maxwell-Jia/AIME_2024 dataset, prompts with a step-by-step
    reasoning instruction, and scores answers by exact/numeric match.
    """

    def __init__(self, num_samples: Optional[int] = None):
        # AIME has no subsets, so the base-class `subset` is always None.
        super().__init__(num_samples, None)

    def load_data(self) -> Tuple[List[Dict[str, Any]], List[Optional[str]]]:
        """Load and preprocess AIME dataset.

        Returns:
            (questions, labels): question dicts keyed by "question", and the
            ground-truth answers as stripped strings (None if missing).
        """
        dataset = load_dataset("Maxwell-Jia/AIME_2024")["train"]
        questions = []
        labels = []
        for idx, q in enumerate(dataset):
            # Truncate to the first `num_samples` rows when a limit is set.
            if self.num_samples is not None and idx >= self.num_samples:
                break

            questions.append({"question": q["Problem"]})
            # Extract answer from Answer field (either capitalization).
            answer = None
            if "Answer" in q:
                answer = str(q["Answer"]).strip()
            elif "answer" in q:
                answer = str(q["answer"]).strip()
            labels.append(answer)
        return questions, labels

    def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[str]:
        """Extract answer from model output (delegates to extract_aime_answer)."""
        return extract_aime_answer(output)

    def compute_accuracy(
        self, predictions: List[Any], labels: List[Any]
    ) -> Optional[float]:
        """Compute accuracy for AIME by comparing numeric answers.

        Returns None when there are no usable labels; otherwise the fraction
        of labeled samples whose prediction matches (string or numeric).
        """
        if not labels or len(labels) == 0:
            return None
        if all(label is None for label in labels):
            return None

        correct = 0
        valid_count = 0
        for pred, label in zip(predictions, labels):
            if label is not None:
                valid_count += 1
                if pred is not None:
                    # Normalize answers for comparison
                    pred_normalized = str(pred).strip()
                    label_normalized = str(label).strip()
                    # Try exact match first
                    if pred_normalized == label_normalized:
                        correct += 1
                    else:
                        # Try numeric comparison (handles "042" vs "42")
                        try:
                            pred_num = int(pred_normalized)
                            label_num = int(label_normalized)
                            if pred_num == label_num:
                                correct += 1
                        except ValueError:
                            pass

        return correct / valid_count if valid_count > 0 else 0.0

    def create_sgl_function(self):
        """Create SGL function for AIME with reasoning prompt."""
        return create_simple_sgl_function(
            function_name="reasoning_gen",
            answer_key="answer",
            user_prefix="\nPlease reason step by step, and put your final answer within \\boxed{}.",
        )

    def get_max_new_tokens(self) -> int:
        """AIME problems require more tokens (long chain-of-thought)."""
        return 32768
progress/SpecForge/benchmarks/benchmarker/base.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Base class for benchmark implementations.
3
+ """
4
+
5
+ import time
6
+ from abc import ABC, abstractmethod
7
+ from argparse import Namespace
8
+ from typing import Any, Callable, Dict, List, Optional, Tuple
9
+
10
+ from sglang import set_default_backend
11
+ from sglang.test.test_utils import select_sglang_backend
12
+
13
+ from .utils import compute_metrics
14
+
15
+
16
class Benchmarker(ABC):
    """
    Base class for benchmark implementations.

    Subclasses should implement:
    - load_data(): Load and preprocess dataset
    - create_sgl_function(): Create the SGL function for inference

    Optional overrides:
    - extract_answer(): Extract answer from model output (if needed)
    - compute_accuracy(): Compute accuracy metric (if applicable)
    - get_answer_keys(): Get list of answer keys for multi-turn conversations

    Args:
        num_samples: The number of samples to run the benchmark on. If not provided, all questions will be used.
        subset: The subset of the dataset to run the benchmark on. If not provided, all subsets will be used.
    """

    def __init__(
        self, num_samples: Optional[int] = None, subset: Optional[List[str]] = None
    ):
        self.num_samples = num_samples
        self.subset = subset

    @abstractmethod
    def load_data(self) -> Tuple[List[Dict[str, Any]], List[Any]]:
        """
        Load and preprocess the dataset.

        Returns:
            Tuple of (questions, labels) where:
            - questions: List of question dicts for SGL function
            - labels: List of ground truth labels (can be None if not applicable)
        """
        raise NotImplementedError

    @abstractmethod
    def create_sgl_function(self) -> Callable:
        """
        Create the SGL function for inference.

        Returns:
            SGL function decorated with @sgl.function
        """
        raise NotImplementedError

    def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[Any]:
        """
        Extract answer from model output.

        Default implementation returns the raw output unchanged.

        Args:
            output: Raw model output string
            label: Optional ground truth label for reference

        Returns:
            Extracted answer, or None if extraction fails
        """
        return output

    def compute_accuracy(
        self, predictions: List[Any], labels: List[Any]
    ) -> Optional[float]:
        """
        Compute accuracy metric. Default: no accuracy (returns None).

        Args:
            predictions: List of predicted answers
            labels: List of ground truth labels

        Returns:
            Accuracy score (0-1), or None if not applicable
        """
        return None

    def get_answer_keys(self) -> Optional[List[str]]:
        """
        Get list of answer keys for multi-turn conversations.

        Returns:
            List of answer keys (e.g., ["answer_1", "answer_2"]), or None for single-turn
        """
        return None

    def get_max_new_tokens(self) -> int:
        """
        Get maximum number of new tokens to generate.

        Returns:
            Maximum tokens (default: 2048)
        """
        return 2048

    def run(
        self,
        host: str,
        port: int,
        batch_size: int,
        max_new_tokens: Optional[int] = None,
        num_runs: int = 1,
    ):
        """
        Run the benchmark evaluation.

        This method handles the common workflow:
        1. Initialize backend
        2. Load data
        3. Create SGL function
        4. Run inference loops
        5. Compute metrics
        6. Print results

        Args:
            host (str): The host of the SGLang server
            port (int): The port of the SGLang server
            batch_size (int): The number of prompts to process in parallel
            max_new_tokens (int): Maximum number of new tokens to generate; defaults to get_max_new_tokens()
            num_runs (int): The number of times to run this benchmark, default is 1. You can set it to a larger number if you want to get more stable results.

        Returns:
            List of per-run metrics objects, or None if no questions loaded.
        """
        if not host.startswith(("http://", "https://")):
            host = f"http://{host}"
        # Initialize backend
        sglang_args = Namespace(host=host, port=port, backend="srt-no-parallel")
        set_default_backend(select_sglang_backend(sglang_args))

        # Load data
        questions, labels = self.load_data()
        if len(questions) == 0:
            print("No valid questions found. Please check the dataset format.")
            return

        # Create SGL function
        sgl_function = self.create_sgl_function()

        # Run evaluation loops
        metrics_list = []
        answer_keys = self.get_answer_keys()
        max_new_tokens = max_new_tokens or self.get_max_new_tokens()

        for _ in range(num_runs):
            # Time the whole batch; greedy decoding (temperature=0).
            tic = time.perf_counter()
            states = sgl_function.run_batch(
                questions,
                temperature=0,
                max_new_tokens=max_new_tokens,
                num_threads=batch_size,
                progress_bar=True,
            )
            latency = time.perf_counter() - tic

            # Extract predictions
            predictions = []
            primary_answer_key = answer_keys[0] if answer_keys else "answer"
            for i in range(len(states)):
                # Access answer from state object (states[i] supports dict-like access)
                output = states[i][primary_answer_key]
                if isinstance(output, str):
                    extracted = self.extract_answer(
                        output,
                        (labels[i] if labels and i < len(labels) else None),
                    )
                else:
                    # Non-string outputs are passed through unmodified.
                    extracted = output
                predictions.append(extracted)

            # Compute accuracy if applicable
            accuracy = None
            # Check if we have a labels list (even if all labels are None)
            has_labels_list = labels and len(labels) > 0

            if has_labels_list:
                # Always call compute_accuracy if we have a labels list
                # This allows it to return None, which will be displayed in print_results
                accuracy = self.compute_accuracy(predictions, labels)
                if accuracy is not None:
                    valid_count = sum(1 for p in predictions if p is not None)
                    if valid_count < len(predictions):
                        print(
                            f"Warning: {len(predictions) - valid_count} predictions could not be extracted."
                        )

            # Compute performance metrics
            metrics = compute_metrics(
                states,
                latency,
                answer_key=primary_answer_key,
                additional_answer_keys=(
                    answer_keys[1:] if answer_keys and len(answer_keys) > 1 else None
                ),
            )
            # Always set accuracy if we have a labels list (even if compute_accuracy returns None)
            # This allows print_results to show None when compute_accuracy returns None
            if has_labels_list:
                metrics.accuracy = (
                    accuracy  # Can be None if compute_accuracy returns None
                )
                if accuracy is not None:
                    metrics.num_valid_predictions = sum(
                        1 for p in predictions if p is not None
                    )

            metrics_list.append(metrics)
        return metrics_list
progress/SpecForge/benchmarks/benchmarker/ceval.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ C-Eval benchmark evaluation script.
3
+ """
4
+
5
+ import re
6
+ from typing import Any, Dict, List, Optional, Tuple
7
+
8
+ from datasets import concatenate_datasets, load_dataset
9
+
10
+ from .base import Benchmarker
11
+ from .registry import BENCHMARKS
12
+ from .utils import create_simple_sgl_function
13
+
14
+
15
+ def extract_answer(answer_str: str) -> str:
16
+ """Extract the answer choice (A, B, C, D) from the model output."""
17
+ # Try to find the answer in various formats
18
+ answer_str = answer_str.strip().upper()
19
+
20
+ # Direct match for single letter
21
+ match = re.search(r"\b([ABCD])\b", answer_str)
22
+ if match:
23
+ return match.group(1)
24
+
25
+ # Try to find answer in parentheses or brackets
26
+ for pattern in [
27
+ r"\(([ABCD])\)",
28
+ r"\[([ABCD])\]",
29
+ r"答案[::]\s*([ABCD])",
30
+ r"Answer[::]\s*([ABCD])",
31
+ ]:
32
+ match = re.search(pattern, answer_str, re.IGNORECASE)
33
+ if match:
34
+ return match.group(1).upper()
35
+
36
+ # Try to find the first occurrence of A, B, C, or D
37
+ match = re.search(r"([ABCD])", answer_str)
38
+ if match:
39
+ return match.group(1)
40
+
41
+ return None
42
+
43
+
44
def format_question(question: str, options: List[str]) -> str:
    """Render a multiple-choice prompt: question, lettered options, instruction."""
    option_lines = [
        f"{chr(65 + idx)}. {text}\n" for idx, text in enumerate(options)
    ]
    return (
        question
        + "\n\n选项:\n"
        + "".join(option_lines)
        + "\n请从A、B、C、D中选择一个答案。"
    )
51
+
52
+
53
@BENCHMARKS.register("ceval")
class CEvalBenchmarker(Benchmarker):
    """C-Eval benchmark implementation.

    C-Eval is a Chinese multiple-choice exam suite split into per-subject
    configs; either every config or a caller-supplied subset is loaded.
    """

    def __init__(
        self, num_samples: Optional[int] = None, subset: Optional[List[str]] = None
    ):
        # num_samples caps the number of kept questions (None = keep all).
        # NOTE(review): the default sentinel here is the *string* "all",
        # while MMLUBenchmarker uses the list ["all"]; load_data below
        # special-cases the string — confirm this asymmetry is intended.
        if subset is None:
            subset = "all"
        super().__init__(num_samples, subset)

    def load_data(self) -> Tuple[List[Dict[str, Any]], List[str]]:
        """Load and preprocess C-Eval dataset.

        Returns (questions, labels): questions are {"question": prompt}
        dicts, labels are gold letters "A"-"D".  Returns ([], []) when
        nothing could be loaded or no item survived validation.
        """
        # Full list of C-Eval subject configs on the Hugging Face hub;
        # used both as the "all" expansion and to validate user subsets.
        all_configs = [
            "accountant",
            "advanced_mathematics",
            "art_studies",
            "basic_medicine",
            "business_administration",
            "chinese_language_and_literature",
            "civil_servant",
            "clinical_medicine",
            "college_chemistry",
            "college_economics",
            "college_physics",
            "college_programming",
            "computer_architecture",
            "computer_network",
            "discrete_mathematics",
            "education_science",
            "electrical_engineer",
            "environmental_impact_assessment_engineer",
            "fire_engineer",
            "high_school_biology",
            "high_school_chemistry",
            "high_school_chinese",
            "high_school_geography",
            "high_school_history",
            "high_school_mathematics",
            "high_school_physics",
            "high_school_politics",
            "ideological_and_moral_cultivation",
            "law",
            "legal_professional",
            "logic",
            "mao_zedong_thought",
            "marxism",
            "metrology_engineer",
            "middle_school_biology",
            "middle_school_chemistry",
            "middle_school_geography",
            "middle_school_history",
            "middle_school_mathematics",
            "middle_school_physics",
            "middle_school_politics",
            "modern_chinese_history",
            "operating_system",
            "physician",
            "plant_protection",
            "probability_and_statistics",
            "professional_tour_guide",
            "sports_science",
            "tax_accountant",
            "teacher_qualification",
            "urban_and_rural_planner",
            "veterinary_medicine",
        ]

        # Select configs to load
        if self.subset == "all":
            configs_to_load = all_configs
        else:
            for subset in self.subset:
                assert (
                    subset in all_configs
                ), f"Subset {subset} not found in C-Eval dataset"
            configs_to_load = self.subset

        # Load datasets; a single config failing is tolerated as long as at
        # least one config loads — total failure returns empty lists.
        try:
            datasets = []
            for config in configs_to_load:
                try:
                    ds = load_dataset("ceval/ceval-exam", name=config, split="test")
                    datasets.append(ds)
                    print(f"Loaded config '{config}' with {len(ds)} samples")
                except Exception as e:
                    print(f"Warning: Failed to load config '{config}': {e}")
            if len(datasets) == 0:
                raise ValueError("No configs could be loaded")
            dataset = concatenate_datasets(datasets)
            print(
                f"Successfully loaded C-Eval dataset with all configs (total: {len(dataset)} samples)"
            )
        except Exception as e:
            print(e)
            print(f"Failed to load C-Eval dataset from 'ceval/ceval-exam': {e}")
            print("Please ensure the dataset is available or install it manually.")
            print("You can try: pip install datasets")
            print("Or download from: https://huggingface.co/datasets/ceval/ceval-exam")
            return [], []

        # Process questions
        questions = []
        labels = []
        for idx, item in enumerate(dataset):
            if self.num_samples is not None and idx >= self.num_samples:
                break

            # Handle different dataset formats: several possible field names
            # are probed because mirror datasets disagree on the schema.
            question_text = None
            if "question" in item:
                question_text = item["question"]
            elif "inputs" in item:
                question_text = item["inputs"]
            elif "problem" in item:
                question_text = item["problem"]
            elif "content" in item:
                question_text = item["content"]

            if not question_text:
                continue

            # Get options - C-Eval typically has options as a list or dict
            options = None
            if "options" in item:
                options = item["options"]
                if isinstance(options, dict):
                    # Convert dict to list in order A, B, C, D
                    options = [
                        options.get("A", ""),
                        options.get("B", ""),
                        options.get("C", ""),
                        options.get("D", ""),
                    ]
                elif isinstance(options, list):
                    # Ensure we have 4 options
                    while len(options) < 4:
                        options.append("")
            elif "choices" in item:
                options = item["choices"]
                if isinstance(options, dict):
                    options = [
                        options.get("A", ""),
                        options.get("B", ""),
                        options.get("C", ""),
                        options.get("D", ""),
                    ]
            else:
                # Try to construct options from A, B, C, D fields
                options = [
                    item.get("A", item.get("option_A", "")),
                    item.get("B", item.get("option_B", "")),
                    item.get("C", item.get("option_C", "")),
                    item.get("D", item.get("option_D", "")),
                ]

            # Filter out empty options
            if options:
                options = [str(opt).strip() for opt in options if opt]
                if len(options) < 2:  # Need at least 2 options
                    continue
            else:
                continue

            # Get answer — again probe several possible gold-label fields.
            answer = None
            if "answer" in item:
                answer = str(item["answer"]).upper().strip()
            elif "target" in item:
                answer = str(item["target"]).upper().strip()
            elif "label" in item:
                answer = str(item["label"]).upper().strip()
            elif "correct" in item:
                answer = str(item["correct"]).upper().strip()

            # Validate answer: only items with a clean A-D gold label are kept.
            if answer and answer in ["A", "B", "C", "D"]:
                # Format question
                formatted_question = format_question(question_text, options)
                questions.append({"question": formatted_question})
                labels.append(answer)

        if len(questions) == 0:
            print("No valid questions found. Please check the dataset format.")
            print(
                "Sample item keys:",
                list(dataset[0].keys()) if len(dataset) > 0 else "No items",
            )
            return [], []

        return questions, labels

    def create_sgl_function(self):
        """Create SGL function for C-Eval."""
        return create_simple_sgl_function(
            function_name="get_ceval_answer",
            answer_key="answer",
            max_tokens=self.get_max_new_tokens(),
        )

    def extract_answer(self, output: str, label: Any = None) -> Optional[str]:
        """Extract answer choice from model output.

        Delegates to the module-level extract_answer (same name as this
        method); returns None when no A-D letter is found.
        """
        return extract_answer(output)

    def compute_accuracy(self, predictions: List[str], labels: List[str]) -> float:
        """Compute accuracy over *extractable* predictions only.

        None predictions are excluded from the denominator, so this is the
        accuracy among answers the parser could read, not over all items.
        """
        correct = 0
        valid_count = 0
        for i in range(len(predictions)):
            if predictions[i] is not None:  # Only count valid predictions
                valid_count += 1
                if predictions[i] == labels[i]:
                    correct += 1
        return correct / valid_count if valid_count > 0 else 0.0
progress/SpecForge/benchmarks/benchmarker/financeqa.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Optional, Tuple
2
+
3
+ from datasets import load_dataset
4
+
5
+ from .base import Benchmarker
6
+ from .registry import BENCHMARKS
7
+ from .utils import create_simple_sgl_function
8
+
9
QUESTION_PROMPT = """
Given the following context:

{context}

Can you answer the following question?

{question}
""".strip()


def generate_question(row: Dict[str, Any]) -> str:
    """Build the prompt for one FinanceQA row.

    Rows without a context get the bare question; otherwise the context is
    prepended via QUESTION_PROMPT.
    """
    question = row["question"].strip()
    context = row["context"]
    if context is None:
        return question
    return QUESTION_PROMPT.format(context=context.strip(), question=question)
29
+
30
+
31
@BENCHMARKS.register("financeqa")
class FinanceQABenchmarker(Benchmarker):
    """FinanceQA benchmark implementation (open-ended answers, no labels)."""

    def __init__(self, num_samples: Optional[int] = None):
        super().__init__(num_samples, None)

    def load_data(self) -> Tuple[List[Dict[str, Any]], List[int]]:
        """Load the FinanceQA test split and build one prompt per row."""
        ds = load_dataset("AfterQuery/FinanceQA")["test"]

        questions = []
        labels = []
        for idx in range(len(ds)):
            if self.num_samples is not None and idx >= self.num_samples:
                break
            questions.append({"question": generate_question(ds[idx])})
            # FinanceQA carries no gold labels here, so accuracy is skipped.
            labels.append(None)
        return questions, labels

    def create_sgl_function(self):
        """Create the SGL function used to query the model."""
        return create_simple_sgl_function(
            function_name="get_financeqa_answer",
            answer_key="answer",
            max_tokens=self.get_max_new_tokens(),
        )
progress/SpecForge/benchmarks/benchmarker/gpqa.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from typing import Any, Dict, List, Optional, Tuple
3
+
4
+ from datasets import load_dataset
5
+
6
+ from .base import Benchmarker
7
+ from .registry import BENCHMARKS
8
+ from .utils import create_simple_sgl_function
9
+
10
GPQA_QUERY_TEMPLATE = """
Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.

{Question}

A) {A}
B) {B}
C) {C}
D) {D}
""".strip()


def generate_question(row: Dict[str, Any]) -> Tuple[str, str]:
    """Build a GPQA multiple-choice prompt with the gold answer at a random slot.

    Returns (prompt, gold_letter) where gold_letter is one of "A".."D".
    (Fix: the previous annotation claimed ``-> str`` although a tuple is
    returned.)  The gold position is drawn from the module-level ``random``
    state, so seed ``random`` for a reproducible option order.
    """
    gold_index = random.randint(0, 3)
    choices = [
        row["Incorrect Answer 1"],
        row["Incorrect Answer 2"],
        row["Incorrect Answer 3"],
    ]
    choices.insert(gold_index, row["Correct Answer"])

    question = GPQA_QUERY_TEMPLATE.format(
        Question=row["Question"].strip(),
        A=choices[0].strip(),
        B=choices[1].strip(),
        C=choices[2].strip(),
        D=choices[3].strip(),
    )

    # Index 0..3 maps to letters A..D.
    answer = ["A", "B", "C", "D"][gold_index]
    return question, answer
42
+
43
+
44
@BENCHMARKS.register("gpqa")
class GPQABenchmarker(Benchmarker):
    """GPQA benchmark implementation (multiple-choice, A-D)."""

    def __init__(self, num_samples: Optional[int] = None):
        super().__init__(num_samples, None)

    def load_data(self) -> Tuple[List[Dict[str, Any]], List[str]]:
        """Load GPQA main split and build shuffled multiple-choice prompts.

        Returns (questions, labels); labels are gold letters "A"-"D".
        (Fix: the previous annotation claimed ``List[int]`` labels.)
        """
        ds = load_dataset("Idavidrein/gpqa", "gpqa_main")["train"]

        questions = []
        labels = []
        for i in range(len(ds)):
            if self.num_samples is not None and i >= self.num_samples:
                break

            question_text, answer = generate_question(ds[i])
            questions.append({"question": question_text})
            labels.append(answer)
        return questions, labels

    def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[str]:
        """Return the text after the first 'Answer: ' marker, or None.

        (Fix: the previous annotation claimed ``Optional[int]`` although a
        string is returned.)
        """
        if "Answer: " not in output:
            return None
        # split()[1] stops at a second "Answer: " occurrence, if any.
        return output.split("Answer: ")[1].strip()

    def compute_accuracy(
        self, predictions: List[Any], labels: List[Any]
    ) -> Optional[float]:
        """Exact-match accuracy; None when there are no labels."""
        if not labels:
            return None
        correct = sum(1 for pred, label in zip(predictions, labels) if pred == label)
        return correct / len(labels)

    def create_sgl_function(self):
        """Create the SGL function used to query the model."""
        return create_simple_sgl_function(
            function_name="get_gpqa_mcq_answer",
            answer_key="answer",
            max_tokens=self.get_max_new_tokens(),
        )
progress/SpecForge/benchmarks/benchmarker/gsm8k.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GSM8K benchmark evaluation script.
3
+ """
4
+
5
+ import ast
6
+ import re
7
+ from typing import Any, Dict, List, Optional, Tuple
8
+
9
+ from sglang.utils import download_and_cache_file, read_jsonl
10
+
11
+ from .base import Benchmarker
12
+ from .registry import BENCHMARKS
13
+ from .utils import create_few_shot_sgl_function
14
+
15
+ INVALID = -9999999
16
+
17
+
18
def get_one_example(lines: List[Dict], i: int, include_answer: bool) -> str:
    """Format example i as a 'Question: ... Answer:' prompt (few-shot style)."""
    text = f"Question: {lines[i]['question']}\nAnswer:"
    if include_answer:
        text = f"{text} {lines[i]['answer']}"
    return text


def get_few_shot_examples(lines: List[Dict], k: int) -> str:
    """Concatenate the first k answered examples, each followed by a blank line."""
    return "".join(get_one_example(lines, idx, True) + "\n\n" for idx in range(k))
32
+
33
+
34
def get_answer_value(answer_str: str) -> int:
    """Pull the last integer out of a GSM8K answer string; INVALID if none."""
    digit_runs = re.findall(r"\d+", answer_str.replace(",", ""))
    if not digit_runs:
        return INVALID
    try:
        return ast.literal_eval(digit_runs[-1])
    except SyntaxError:
        # literal_eval rejects ints written with leading zeros (e.g. "007").
        return INVALID
44
+
45
+
46
@BENCHMARKS.register("gsm8k")
class GSM8KBenchmarker(Benchmarker):
    """GSM8K benchmark implementation (5-shot grade-school math)."""

    def __init__(self, num_samples: Optional[int] = None):
        super().__init__(num_samples, None)

    def load_data(self) -> Tuple[List[Dict[str, Any]], List[int]]:
        """Download the GSM8K test split; return prompts and integer labels."""
        url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
        rows = list(read_jsonl(download_and_cache_file(url)))

        # The 5-shot prefix is shared by every prompt; kept for
        # create_sgl_function to use later.
        self.few_shot_examples = get_few_shot_examples(rows, 5)

        limit = len(rows)
        if self.num_samples is not None:
            limit = min(limit, self.num_samples)
        questions = [
            {"question": get_one_example(rows, idx, False)} for idx in range(limit)
        ]
        labels = [get_answer_value(rows[idx]["answer"]) for idx in range(limit)]

        assert all(l != INVALID for l in labels), "Some labels are invalid"
        return questions, labels

    def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[int]:
        """Parse the final numeric value out of the model's answer text."""
        return get_answer_value(output)

    def compute_accuracy(
        self, predictions: List[Any], labels: List[Any]
    ) -> Optional[float]:
        """Exact-match accuracy over numeric answers; None without labels."""
        if not labels:
            return None
        matches = sum(1 for guess, gold in zip(predictions, labels) if guess == gold)
        return matches / len(labels)

    def create_sgl_function(self):
        """Create SGL function for GSM8K with the stored few-shot prefix."""
        return create_few_shot_sgl_function(
            few_shot_examples=self.few_shot_examples,
            function_name="few_shot_gsm8k",
            answer_key="answer",
            stop=["Question", "Assistant:", "<|separator|>"],
        )
progress/SpecForge/benchmarks/benchmarker/humaneval.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HumanEval benchmark evaluation script.
3
+ """
4
+
5
+ import re
6
+ from typing import Any, Dict, List, Optional, Tuple
7
+
8
+ from datasets import load_dataset
9
+
10
+ from .base import Benchmarker
11
+ from .registry import BENCHMARKS
12
+ from .utils import create_simple_sgl_function
13
+
14
+
15
def extract_code_from_output(output: str) -> Optional[str]:
    """Pull Python code out of a model response.

    Preference order: a fenced ```python block, then a top-level function
    definition, then the raw stripped text; None when the text is empty.
    """
    fence = re.search(r"```(?:python)?\n(.*?)```", output, re.DOTALL)
    if fence:
        return fence.group(1).strip()

    # A "def ..." body running until the next blank-line-separated def
    # (or end of string) — the common HumanEval completion shape.
    func = re.search(r"(def\s+\w+\([^)]*\):.*?)(?=\n\ndef\s+|\Z)", output, re.DOTALL)
    if func:
        return func.group(1).strip()

    stripped = output.strip()
    return stripped if stripped else None
35
+
36
+
37
def check_code_passes_tests(code: str, test_code: str, entry_point: str) -> bool:
    """Run generated code plus its HumanEval test harness; True iff no exception.

    Simplified stand-in for the official HumanEval evaluator.  HumanEval test
    code is assertion-based, so completing both exec() calls without raising
    means the tests passed.  `entry_point` is unused here but kept for
    interface compatibility.

    SECURITY NOTE: exec() runs untrusted model output in-process — only use
    this in a sandboxed environment.
    """
    env = {}
    try:
        exec(code, env)       # define the candidate function
        exec(test_code, env)  # assertions raise on failure
    except Exception:
        # AssertionError => wrong answer; any other exception (syntax error,
        # runtime error, ...) => broken code. Either way the sample fails.
        return False
    return True
62
+
63
+
64
@BENCHMARKS.register("humaneval")
class HumanEvalBenchmarker(Benchmarker):
    """HumanEval benchmark implementation.

    Labels carry each task's test harness and entry point so that
    compute_accuracy can execute the generated code against the tests.
    """

    def __init__(self, num_samples: Optional[int] = None):
        """Initialize benchmark and store test cases."""
        super().__init__(num_samples, None)
        # Parallel per-task metadata, filled by load_data.
        self.test_cases = []
        self.entry_points = []

    def load_data(self) -> Tuple[List[Dict[str, Any]], List[Optional[Dict[str, str]]]]:
        """Load and preprocess HumanEval dataset.

        Returns (questions, labels): questions hold the function prompt
        (signature + docstring); each label is a dict with the task's
        "test", "entry_point" and "canonical_solution".
        """
        dataset = load_dataset("openai/openai_humaneval")["test"]
        questions = []
        labels = []
        # Reset so repeated load_data calls do not accumulate stale entries.
        self.test_cases = []
        self.entry_points = []

        for idx, q in enumerate(dataset):
            if self.num_samples is not None and idx >= self.num_samples:
                break

            questions.append({"question": q["prompt"]})

            # Store test case and entry point for evaluation
            test_code = q.get("test", "")
            entry_point = q.get("entry_point", "")
            self.test_cases.append(test_code)
            self.entry_points.append(entry_point)

            # Store canonical solution as reference (optional, for comparison)
            canonical_solution = q.get("canonical_solution", "")
            labels.append(
                {
                    "test": test_code,
                    "entry_point": entry_point,
                    "canonical_solution": canonical_solution,
                }
            )

        return questions, labels

    def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[str]:
        """Extract code from model output."""
        return extract_code_from_output(output)

    def compute_accuracy(
        self, predictions: List[Any], labels: List[Any]
    ) -> Optional[float]:
        """Compute accuracy for HumanEval by checking if code passes tests.

        Note: This is a simplified evaluation. For official pass@k metrics,
        use the HumanEval evaluation framework.
        """
        if not labels or len(labels) == 0:
            return None
        if all(label is None for label in labels):
            return None

        correct = 0
        valid_count = 0

        for i, (pred, label) in enumerate(zip(predictions, labels)):
            if label is not None and isinstance(label, dict):
                valid_count += 1
                if pred is not None:
                    try:
                        # Get the prompt (function signature and docstring)
                        # NOTE(review): assumes self.questions was populated
                        # (presumably by the base class from load_data) and
                        # is index-aligned with predictions — confirm.
                        prompt = self.questions[i]["question"]
                        entry_point = label.get("entry_point", "")

                        # The prompt contains the function signature (e.g. "def function_name(...):")
                        # The generated code might be:
                        # 1. Just the function body (what we want) - need to combine with prompt
                        # 2. The complete function including signature - use as-is
                        # 3. Code in markdown blocks - already extracted by extract_code_from_output

                        pred_str = str(pred).strip()

                        # Check if pred already contains a complete function definition
                        # (starts with "def " and contains the entry_point function name)
                        if pred_str.startswith("def ") and entry_point:
                            # Check if this is the same function (by name)
                            func_name_match = re.match(r"def\s+(\w+)\s*\(", pred_str)
                            if (
                                func_name_match
                                and func_name_match.group(1) == entry_point
                            ):
                                # Generated code includes complete function, use it as-is
                                full_code = pred_str
                            else:
                                # Different function or no match, combine with prompt
                                full_code = prompt + "\n" + pred_str
                        elif pred_str.startswith("def "):
                            # Has function definition but we can't verify entry_point, use as-is
                            full_code = pred_str
                        else:
                            # Generated code is just the body, combine with prompt
                            full_code = prompt + "\n" + pred_str

                        # Check if code passes tests
                        test_code = label.get("test", "")

                        if test_code and check_code_passes_tests(
                            full_code, test_code, entry_point
                        ):
                            correct += 1
                    except Exception as e:
                        # If evaluation fails, consider it incorrect
                        # Uncomment for debugging: print(f"Error evaluating code {i}: {e}")
                        pass

        return correct / valid_count if valid_count > 0 else 0.0

    def create_sgl_function(self):
        """Create SGL function for HumanEval."""
        return create_simple_sgl_function(
            function_name="get_humaneval_answer",
            answer_key="answer",
            max_tokens=self.get_max_new_tokens(),
        )

    def get_max_new_tokens(self) -> int:
        """HumanEval code generation requires more tokens."""
        return 1024
progress/SpecForge/benchmarks/benchmarker/livecodebench.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LiveCodeBench benchmark evaluation script.
3
+ """
4
+
5
+ from typing import Any, Dict, List, Optional, Tuple
6
+
7
+ from datasets import load_dataset
8
+
9
+ from .base import Benchmarker
10
+ from .registry import BENCHMARKS
11
+ from .utils import create_simple_sgl_function
12
+
13
+
14
def generate_question(row: Dict[str, Any]) -> str:
    """Return the LiveCodeBench problem statement, stripped of edge whitespace."""
    return row["question_content"].strip()
17
+
18
+
19
@BENCHMARKS.register("livecodebench")
class LCBBenchmarker(Benchmarker):
    """LiveCodeBench benchmark implementation (code generation, no labels)."""

    def __init__(self, num_samples: Optional[int] = None):
        super().__init__(num_samples, None)

    def load_data(self) -> Tuple[List[Dict[str, Any]], List[int]]:
        """Load the code_generation test split; labels are all None."""
        ds = load_dataset("livecodebench/code_generation")["test"]

        questions = []
        labels = []
        for idx in range(len(ds)):
            if self.num_samples is not None and idx >= self.num_samples:
                break
            questions.append({"question": generate_question(ds[idx])})
            # No gold answers are available, so accuracy is never computed.
            labels.append(None)
        return questions, labels

    def create_sgl_function(self):
        """Create the SGL function used to query the model."""
        return create_simple_sgl_function(
            function_name="get_livecodebench_answer",
            answer_key="answer",
            max_tokens=self.get_max_new_tokens(),
        )
progress/SpecForge/benchmarks/benchmarker/math500.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MATH-500 benchmark evaluation script.
3
+ """
4
+
5
+ import re
6
+ from typing import Any, Dict, List, Optional, Tuple
7
+
8
+ from datasets import load_dataset
9
+
10
+ from .base import Benchmarker
11
+ from .registry import BENCHMARKS
12
+ from .utils import create_simple_sgl_function
13
+
14
+
15
def extract_math_answer(output: str) -> Optional[str]:
    """Extract the final answer from a math problem solution.

    Tries, in order: a \\boxed{...} expression (scanned with a brace counter
    so nested braces such as \\boxed{\\frac{1}{2}} are captured whole — the
    previous regex ``[^}]+`` truncated them at the first '}'), a brace-less
    "\\boxed value" form, an explicit "answer ..." phrase, and finally the
    last number appearing anywhere in the text.  Returns None otherwise.
    """
    # \boxed{...}: walk the string tracking brace depth instead of using a
    # regex, so nested braces survive.
    marker = "\\boxed{"
    start = output.find(marker)
    if start != -1:
        depth = 1
        body_start = start + len(marker)
        for pos in range(body_start, len(output)):
            ch = output[pos]
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    content = output[body_start:pos].strip()
                    if content:
                        return content
                    break  # empty \boxed{} — fall through to other patterns

    # \boxed value (without braces)
    match = re.search(r"\\boxed\s+([^\s]+)", output)
    if match:
        return match.group(1).strip()

    # Phrases like "The answer is 42" or "Answer: 3.14"
    answer_patterns = [
        r"(?:answer|Answer|ANSWER)[\s:]+([-+]?\d*\.?\d+)",
        r"(?:is|equals?|=\s*)([-+]?\d*\.?\d+)\s*$",
    ]
    for pattern in answer_patterns:
        matches = re.findall(pattern, output, re.IGNORECASE)
        if matches:
            return matches[-1].strip()

    # Fallback: the last number anywhere in the text
    numbers = re.findall(r"[-+]?\d*\.?\d+", output)
    if numbers:
        return numbers[-1]

    return None
50
+
51
+
52
@BENCHMARKS.register("math500")
class Math500Benchmarker(Benchmarker):
    """MATH-500 benchmark implementation."""

    def __init__(self, num_samples: Optional[int] = None):
        super().__init__(num_samples, None)

    def load_data(self) -> Tuple[List[Dict[str, Any]], List[Optional[str]]]:
        """Load the MATH-500 test split; labels come from 'answer' or 'solution'."""
        dataset = load_dataset("HuggingFaceH4/MATH-500")["test"]
        questions = []
        labels = []
        for idx, row in enumerate(dataset):
            if self.num_samples is not None and idx >= self.num_samples:
                break

            questions.append({"question": row["problem"]})
            if "answer" in row:
                label = str(row["answer"]).strip()
            elif "solution" in row:
                # No explicit answer field: mine one out of the worked solution.
                label = extract_math_answer(row["solution"])
            else:
                label = None
            labels.append(label)
        return questions, labels

    def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[str]:
        """Extract the final answer from the model output."""
        return extract_math_answer(output)

    def compute_accuracy(
        self, predictions: List[Any], labels: List[Any]
    ) -> Optional[float]:
        """Accuracy over labeled items; exact string match, then numeric match."""
        if not labels:
            return None
        if all(label is None for label in labels):
            return None

        hits = 0
        labeled = 0
        for pred, gold in zip(predictions, labels):
            if gold is None:
                continue
            labeled += 1
            if pred is None:
                continue
            # Normalize both sides before comparing.
            norm_pred = str(pred).strip().lower()
            norm_gold = str(gold).strip().lower()
            if norm_pred == norm_gold:
                hits += 1
                continue
            # Fall back to a tolerant numeric comparison.
            try:
                if abs(float(norm_pred) - float(norm_gold)) < 1e-6:
                    hits += 1
            except ValueError:
                pass

        return hits / labeled if labeled > 0 else 0.0

    def create_sgl_function(self):
        """Create the SGL function used to query the model."""
        return create_simple_sgl_function(
            function_name="get_math500_answer",
            answer_key="answer",
            max_tokens=self.get_max_new_tokens(),
        )
progress/SpecForge/benchmarks/benchmarker/mmlu.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Optional, Tuple
2
+
3
+ from datasets import load_dataset
4
+
5
+ from .base import Benchmarker
6
+ from .registry import BENCHMARKS
7
+ from .utils import create_simple_sgl_function
8
+
9
GPQA_QUERY_TEMPLATE = """
Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.

{Question}

A) {A}
B) {B}
C) {C}
D) {D}
""".strip()


def generate_question(row: Dict[str, Any]) -> Tuple[str, str]:
    """Render one MMLU row as a 4-option prompt and return (prompt, gold letter).

    ``row["answer"]`` is the gold index (0 -> A, ..., 3 -> D).
    Fixes: removed a stray debug ``print(answer)`` that spammed stdout once
    per sample, and corrected the return annotation (was ``-> str`` although
    a tuple is returned).
    """
    choices = row["choices"]
    question = GPQA_QUERY_TEMPLATE.format(
        Question=row["question"].strip(),
        A=choices[0].strip(),
        B=choices[1].strip(),
        C=choices[2].strip(),
        D=choices[3].strip(),
    )

    # 0 means A, 1 means B, 2 means C, 3 means D
    answer = ["A", "B", "C", "D"][row["answer"]]
    return question, answer
35
+
36
+
37
@BENCHMARKS.register("mmlu")
class MMLUBenchmarker(Benchmarker):
    """MMLU benchmark implementation (multiple-choice, A-D)."""

    def __init__(
        self, num_samples: Optional[int] = None, subset: Optional[List[str]] = None
    ):
        # Default to the aggregate "all" config of cais/mmlu.
        if subset is None:
            subset = ["all"]
        super().__init__(num_samples, subset)

    def load_data(self) -> Tuple[List[Dict[str, Any]], List[str]]:
        """Load each requested MMLU config's test split and build prompts.

        Returns (questions, labels); labels are gold letters "A"-"D".
        (Fix: the previous annotation claimed ``List[int]`` labels.)
        Note: num_samples caps each subset individually, not the total.
        """
        questions = []
        labels = []

        for subset in self.subset:
            ds = load_dataset("cais/mmlu", subset)["test"]
            for i in range(len(ds)):
                if self.num_samples is not None and i >= self.num_samples:
                    break

                question_text, answer = generate_question(ds[i])
                questions.append({"question": question_text})
                labels.append(answer)
        return questions, labels

    def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[str]:
        """Return the text after the first 'Answer: ' marker, or None.

        (Fix: the previous annotation claimed ``Optional[int]`` although a
        string is returned.)
        """
        if "Answer: " not in output:
            return None
        return output.split("Answer: ")[1].strip()

    def compute_accuracy(
        self, predictions: List[Any], labels: List[Any]
    ) -> Optional[float]:
        """Exact-match accuracy; None when there are no labels."""
        if not labels:
            return None
        correct = sum(1 for pred, label in zip(predictions, labels) if pred == label)
        return correct / len(labels)

    def create_sgl_function(self):
        """Create the SGL function used to query the model."""
        return create_simple_sgl_function(
            function_name="get_mmlu_answer",
            answer_key="answer",
            max_tokens=self.get_max_new_tokens(),
        )
progress/SpecForge/benchmarks/benchmarker/mmstar.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MMStar benchmark evaluation script.
3
+ """
4
+
5
+ import os
6
+ import re
7
+ import shutil
8
+ from typing import Any, Dict, List, Optional, Tuple
9
+
10
+ from datasets import load_dataset
11
+
12
+ from .base import Benchmarker
13
+ from .registry import BENCHMARKS
14
+ from .utils import create_image_sgl_function
15
+
16
+
17
def extract_mmstar_answer(
    output: str, options: Optional[List[str]] = None
) -> Optional[str]:
    """Extract a multiple-choice answer letter from MMStar model output.

    MMStar questions typically have multiple choice options (A, B, C, D,
    etc.).  The search runs on the uppercased output: first a standalone
    capital letter, then parenthesized letters and "answer:"-style markers.

    Args:
        output: Raw model output text.
        options: Option texts for this question; when provided, the letter
            must fall within ``"A"`` .. ``chr(ord("A") + len(options) - 1)``.
            Without options, only "A"-"D" are accepted.

    Returns:
        The answer letter, or None when no valid letter is found.
    """
    output_upper = output.strip().upper()

    def _validated(letter: str) -> Optional[str]:
        # Accept only letters that index an existing option (default: A-D).
        if options:
            max_option = chr(64 + len(options))  # 'A' + (len - 1)
            if "A" <= letter <= max_option:
                return letter
        elif "A" <= letter <= "D":
            return letter
        return None

    # Direct match for a standalone single letter.
    match = re.search(r"\b([A-Z])\b", output_upper)
    if match:
        letter = _validated(match.group(1))
        if letter is not None:
            return letter

    # Fallback: letter in parentheses/brackets or after an "answer" marker.
    # BUG FIX: the text was uppercased above, so the patterns must be
    # uppercase too — the previous mixed-case "Answer[::]" pattern could
    # never match "ANSWER:".
    for pattern in [
        r"\(([A-Z])\)",
        r"\[([A-Z])\]",
        r"答案[::]\s*([A-Z])",
        r"ANSWER[::]\s*([A-Z])",
        r"选择[::]\s*([A-Z])",
    ]:
        match = re.search(pattern, output_upper)
        if match:
            letter = _validated(match.group(1))
            if letter is not None:
                return letter

    return None
60
+
61
+
62
@BENCHMARKS.register("mmstar")
class MMStarBenchmarker(Benchmarker):
    """MMStar benchmark implementation (image-based multiple choice)."""

    def __init__(self, num_samples: Optional[int] = None):
        """Initialize benchmark and set up cache-directory bookkeeping.

        FIX: this docstring previously sat *after* the ``super().__init__()``
        call, where it was a dead string statement rather than documentation.
        """
        super().__init__(num_samples, None)
        self.cache_dir = None  # Temporary image cache; created in load_data().
        self.options_list = []  # Parsed option texts for each question, in order.

    def load_data(self) -> Tuple[List[Dict[str, Any]], List[Optional[str]]]:
        """Load and preprocess the MMStar dataset.

        Saves each question's image under a temporary cache directory and
        splits the option block out of the prompt text.

        Returns:
            (questions, labels): each question dict carries "image_path" and
            "question"; labels are validated answer letters or None.
        """
        self.cache_dir = os.path.join(".cache", "mmstar_specforge")
        image_dir = os.path.join(self.cache_dir, "images")
        os.makedirs(self.cache_dir, exist_ok=True)
        os.makedirs(image_dir, exist_ok=True)
        print(f"Created temporary image directory: {self.cache_dir}")

        dataset = load_dataset("Lin-Chen/MMStar")["val"]
        questions = []
        labels = []
        self.options_list = []

        for idx, q in enumerate(dataset):
            if self.num_samples is not None and idx >= self.num_samples:
                break

            image = q["image"]
            # NOTE(review): images are written under cache_dir using the
            # dataset-provided relative path; assumes that path resolves into
            # the pre-created directories ("images/") — verify against the
            # dataset layout.
            image_path = os.path.join(self.cache_dir, q["meta_info"]["image_path"])
            image.convert("RGB").save(image_path, "JPEG")

            # Split the "Options:" block off the prompt and parse "A. ..." lines.
            question_full = q["question"]
            if "Options:" in question_full:
                question_text, options_text = question_full.split("Options:", 1)
                question_text = question_text.strip()
                options = []
                for line in options_text.strip().split("\n"):
                    line = line.strip()
                    if line and re.match(r"^[A-Z]\.", line):
                        option_text = re.sub(r"^[A-Z]\.\s*", "", line).strip()
                        options.append(option_text)
                self.options_list.append(options)
            else:
                question_text = question_full.strip()
                self.options_list.append([])

            item = {
                "image_path": image_path,
                "question": question_text,
            }
            questions.append(item)

            # The ground-truth answer may live under different keys depending
            # on the dataset version.
            answer = None
            if "answer" in q:
                answer = str(q["answer"]).strip().upper()
            elif "correct_answer" in q:
                answer = str(q["correct_answer"]).strip().upper()
            elif "ground_truth" in q:
                answer = str(q["ground_truth"]).strip().upper()

            # Keep only single letters that index an existing option.
            if answer and len(answer) == 1 and "A" <= answer <= "Z":
                if self.options_list[-1]:
                    max_option = chr(64 + len(self.options_list[-1]))
                    if answer <= max_option:
                        labels.append(answer)
                    else:
                        labels.append(None)
                else:
                    labels.append(answer)
            else:
                labels.append(None)

        return questions, labels

    def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[str]:
        """Extract the answer letter from model output.

        Note: the per-question option list is not accessible here, so the
        extractor falls back to its default A-D range.
        """
        return extract_mmstar_answer(output)

    def compute_accuracy(
        self, predictions: List[Any], labels: List[Any]
    ) -> Optional[float]:
        """Compute accuracy over questions that have a valid gold label."""
        if not labels or len(labels) == 0:
            return None
        if all(label is None for label in labels):
            return None

        correct = 0
        valid_count = 0
        for pred, label in zip(predictions, labels):
            if label is not None:
                valid_count += 1
                if pred is not None:
                    # Normalize both sides to uppercase before comparing.
                    pred_normalized = str(pred).strip().upper()
                    label_normalized = str(label).strip().upper()
                    if pred_normalized == label_normalized:
                        correct += 1

        return correct / valid_count if valid_count > 0 else 0.0

    def create_sgl_function(self):
        """Create SGL function for MMStar (image-based Q&A)."""
        return create_image_sgl_function(
            function_name="get_mmstar_answer",
            answer_key="answer",
            max_tokens=self.get_max_new_tokens(),
        )

    def run(self, *args, **kwargs):
        """Run the benchmark, always cleaning up the image cache afterwards."""
        try:
            return super().run(*args, **kwargs)
        finally:
            if self.cache_dir and os.path.exists(self.cache_dir):
                shutil.rmtree(self.cache_dir)
                print(f"Deleted temporary directory: {self.cache_dir}")
progress/SpecForge/benchmarks/benchmarker/mtbench.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MT-Bench benchmark evaluation script.
3
+ Adapted from https://github.com/chromecast56/sglang/blob/6f145d2eadb93a116134f703358ce76f15381045/benchmark/mtbench/bench_sglang.py
4
+ """
5
+
6
+ from typing import Any, Dict, List, Optional, Tuple
7
+
8
+ from sglang.utils import download_and_cache_file, read_jsonl
9
+
10
+ from .base import Benchmarker
11
+ from .registry import BENCHMARKS
12
+ from .utils import create_multi_turn_sgl_function
13
+
14
+ SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."
15
+
16
+
17
@BENCHMARKS.register("mtbench")
class MTBenchBenchmarker(Benchmarker):
    """MT-Bench benchmark implementation (two-turn open-ended questions)."""

    def __init__(
        self, num_samples: Optional[int] = None, subset: Optional[List[str]] = None
    ):
        # support categorical data for mtbench
        if subset is None:
            subset = ["all"]
        super().__init__(num_samples, subset)

    def load_data(self) -> Tuple[List[Dict[str, Any]], List[None]]:
        """Load and preprocess the MT-Bench question set.

        Returns:
            (questions, labels): each question dict carries both turns of the
            conversation; labels are all None — MT-Bench has no gold answers
            for automatic accuracy computation.
        """
        url = "https://raw.githubusercontent.com/lm-sys/FastChat/main/fastchat/llm_judge/data/mt_bench/question.jsonl"
        download_and_cache_file(url, filename="mtbench.jsonl")
        # (Removed a no-op self-assignment that followed this line.)
        questions_data = list(read_jsonl("mtbench.jsonl"))

        questions = [
            {"question_1": q["turns"][0], "question_2": q["turns"][1]}
            for q in questions_data
        ]
        labels = [None] * len(questions)

        if self.num_samples is not None:
            questions = questions[: self.num_samples]
            labels = labels[: self.num_samples]
        return questions, labels

    def create_sgl_function(self):
        """Create SGL function for MT-Bench (2-turn conversation)."""
        return create_multi_turn_sgl_function(
            function_name="answer_mt_bench",
            system_prompt=SYSTEM_PROMPT,
            num_turns=2,
            max_tokens=self.get_max_new_tokens(),
        )

    def get_answer_keys(self) -> List[str]:
        """Return the answer keys for the two conversation turns."""
        return ["answer_1", "answer_2"]
progress/SpecForge/benchmarks/benchmarker/registry.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
class BenchmarkRegistry:
    """A name -> class registry for benchmark implementations."""

    def __init__(self):
        # Maps a benchmark name to its Benchmarker subclass.
        self.benchmarks = {}

    def register(self, name: str):
        """Return a class decorator that registers the class under ``name``.

        Usage:
        ```python
        BENCHMARKS = BenchmarkRegistry()

        @BENCHMARKS.register("aime")
        class AIMEBenchmarker(Benchmarker):
            ...
        ```
        (The example previously omitted the ``@`` — without it nothing is
        registered.)
        """

        def wrapper(cls):
            self.benchmarks[name] = cls
            return cls

        return wrapper

    def get(self, name: str) -> type:
        """Get the benchmark class registered under ``name``.

        Raises:
            KeyError: if no benchmark was registered under ``name``.
        """
        return self.benchmarks[name]


# Module-level singleton used by all benchmark modules.
BENCHMARKS = BenchmarkRegistry()
progress/SpecForge/benchmarks/benchmarker/simpleqa.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Optional, Tuple
2
+
3
+ from datasets import load_dataset
4
+
5
+ from .base import Benchmarker
6
+ from .registry import BENCHMARKS
7
+ from .utils import create_simple_sgl_function
8
+
9
+
10
def generate_question(row: Dict[str, Any]) -> str:
    """Return the SimpleQA prompt: the row's "problem" text, stripped."""
    return row["problem"].strip()
13
+
14
+
15
@BENCHMARKS.register("simpleqa")
class SimpleQABenchmarker(Benchmarker):
    """SimpleQA benchmark implementation (free-form factual QA)."""

    def __init__(self, num_samples: Optional[int] = None):
        super().__init__(num_samples, None)

    def load_data(self) -> Tuple[List[Dict[str, Any]], List[None]]:
        """Load SimpleQA test questions.

        Returns:
            (questions, labels): labels are all None — accuracy is not
            computed automatically for this free-form benchmark.
            (Previous annotation claimed ``List[int]`` labels.)
        """
        ds = load_dataset("basicv8vc/SimpleQA")["test"]

        # Cap the number of samples when requested.
        if self.num_samples is None:
            limit = len(ds)
        else:
            limit = min(self.num_samples, len(ds))

        questions = [{"question": generate_question(ds[i])} for i in range(limit)]
        labels = [None] * limit
        return questions, labels

    def create_sgl_function(self):
        """Build the single-turn SGL answering function for SimpleQA."""
        return create_simple_sgl_function(
            function_name="get_simpleqa_answer",
            answer_key="answer",
            max_tokens=self.get_max_new_tokens(),
        )
progress/SpecForge/benchmarks/benchmarker/utils.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utility functions for benchmark scripts.
3
+ """
4
+
5
+ from dataclasses import dataclass
6
+ from typing import Any, Callable, Dict, List, Optional
7
+
8
+ import numpy as np
9
+ import sglang as sgl
10
+
11
+
12
@dataclass
class BenchmarkMetrics:
    """Container for benchmark performance metrics."""

    # Total wall-clock latency of the run, in seconds.
    latency: float
    # Generated tokens per second (completion tokens / latency).
    output_throughput: float
    # Average accepted length for speculative decoding
    # (1.0 when speculative decoding did not run).
    accept_length: float
    # Task accuracy in [0, 1]; None for benchmarks without gold labels.
    accuracy: Optional[float] = None
    # Number of questions evaluated in this run.
    num_questions: int = 0
    # Number of predictions that could be parsed from model output.
    num_valid_predictions: int = 0
    # Optional per-category breakdown, keyed by category name.
    categorical_performance: Optional[Dict[str, "BenchmarkMetrics"]] = None
23
+
24
+
25
def compute_metrics(
    states: List[Any],
    latency: float,
    answer_key: str = "answer",
    additional_answer_keys: Optional[List[str]] = None,
) -> BenchmarkMetrics:
    """
    Compute performance metrics from SGLang states.

    Args:
        states: List of SGLang state objects from run_batch
        latency: Total latency in seconds
        answer_key: Primary key for answer in state meta info
        additional_answer_keys: Additional keys to include in token count
            (e.g., ["answer_1", "answer_2"])

    Returns:
        BenchmarkMetrics object with computed metrics
    """
    # Robustness: an empty batch has no tokens and no meaningful ratios
    # (the previous version raised IndexError on states[0]).
    if not states:
        return BenchmarkMetrics(
            latency=latency,
            output_throughput=0.0,
            accept_length=1.0,
            num_questions=0,
        )

    # All meta-info keys that contributed generated tokens; unifying the
    # single-key and multi-key cases removes the duplicated branches.
    keys = [answer_key] + (additional_answer_keys or [])

    num_output_tokens = sum(
        s.get_meta_info(key)["completion_tokens"] for s in states for key in keys
    )

    output_throughput = num_output_tokens / latency if latency > 0 else 0.0

    # Accept length (speculative decoding metric): generated tokens per
    # verification step.  "spec_verify_ct" is only present in the meta info
    # when speculative decoding actually ran.
    if "spec_verify_ct" in states[0].get_meta_info(answer_key):
        num_verify_tokens = sum(
            s.get_meta_info(key).get("spec_verify_ct", 0)
            for s in states
            for key in keys
        )
        if num_verify_tokens == 0:
            accept_length = 1.0
        else:
            accept_length = num_output_tokens / num_verify_tokens
    else:
        accept_length = 1.0

    return BenchmarkMetrics(
        latency=latency,
        output_throughput=output_throughput,
        accept_length=accept_length,
        num_questions=len(states),
    )
84
+
85
+
86
def print_results(
    metrics_list: List[BenchmarkMetrics],
    benchmark_name: str,
    show_accuracy: bool = False,
):
    """Pretty-print averaged results from one or more benchmark runs.

    Args:
        metrics_list: BenchmarkMetrics collected from repeated runs.
        benchmark_name: Display name of the benchmark.
        show_accuracy: Whether to include the accuracy line.
    """
    separator = "=" * 50
    mean_latency = np.mean([item.latency for item in metrics_list])
    mean_throughput = np.mean([item.output_throughput for item in metrics_list])
    mean_accept = np.mean([item.accept_length for item in metrics_list])

    print(f"\n{separator}")
    print(f"{benchmark_name} Evaluation Results")
    print(separator)
    print(f"Number of questions: {metrics_list[0].num_questions}")
    if show_accuracy:
        if metrics_list[0].accuracy is None:
            print("Average Accuracy: None")
        else:
            # Average only over runs that actually produced an accuracy.
            accuracies = [
                item.accuracy for item in metrics_list if item.accuracy is not None
            ]
            mean_accuracy = np.mean(accuracies)
            print(f"Average Accuracy: {mean_accuracy:.4f} ({mean_accuracy*100:.2f}%)")
    print(f"Average Latency: {mean_latency:.3f} s")
    print(f"Average Output throughput: {mean_throughput:.3f} token/s")
    print(f"Average Accept length: {mean_accept:.3f}")
    print(f"{separator}\n")
119
+
120
+
121
def create_simple_sgl_function(
    function_name: str = "get_answer",
    answer_key: str = "answer",
    system_prompt: Optional[str] = None,
    max_tokens: int = 2048,
    stop: Optional[List[str]] = None,
    user_prefix: Optional[str] = None,
) -> Callable:
    """
    Create a simple SGL function for single-turn Q&A.

    Args:
        function_name: Name assigned to the returned function.
        answer_key: Key under which the generated answer is stored.
        system_prompt: Optional system prompt prepended to the conversation.
        max_tokens: Maximum tokens to generate.
        stop: Optional stop sequences.
        user_prefix: Optional text appended after the question in the user
            message (despite the name, it is appended, not prepended).

    Returns:
        SGL function decorated with @sgl.function
    """

    @sgl.function
    def sgl_func(s, question):
        if system_prompt:
            s += sgl.system(system_prompt)
        content = question + user_prefix if user_prefix else question
        s += sgl.user(content)
        params = {"max_tokens": max_tokens}
        if stop:
            params["stop"] = stop
        s += sgl.assistant(sgl.gen(answer_key, **params))

    sgl_func.__name__ = function_name
    return sgl_func
159
+
160
+
161
def create_few_shot_sgl_function(
    few_shot_examples: str,
    function_name: str = "few_shot_answer",
    answer_key: str = "answer",
    max_tokens: int = 512,
    stop: Optional[List[str]] = None,
) -> Callable:
    """
    Create an SGL function for few-shot learning.

    Args:
        few_shot_examples: Few-shot examples prepended verbatim to the question.
        function_name: Name assigned to the returned function.
        answer_key: Key under which the generated answer is stored.
        max_tokens: Maximum tokens to generate.
        stop: Optional stop sequences.

    Returns:
        SGL function decorated with @sgl.function
    """

    @sgl.function
    def sgl_func(s, question):
        # Raw completion style: examples followed by the question, no chat roles.
        s += few_shot_examples + question
        params = {"max_tokens": max_tokens}
        if stop:
            params["stop"] = stop
        s += sgl.gen(answer_key, **params)

    sgl_func.__name__ = function_name
    return sgl_func
192
+
193
+
194
def create_multi_turn_sgl_function(
    function_name: str = "multi_turn_answer",
    system_prompt: Optional[str] = None,
    num_turns: int = 2,
    max_tokens: int = 2048,
) -> Callable:
    """
    Create an SGL function for multi-turn conversations (e.g., MT-Bench with 2 turns).

    Args:
        function_name: Name assigned to the returned function.
        system_prompt: Optional system prompt.
        num_turns: Number of conversation turns (default: 2).
        max_tokens: Maximum tokens to generate per turn.

    Returns:
        SGL function decorated with @sgl.function
    """
    if num_turns == 2:
        # Common case: two explicit question parameters.
        @sgl.function
        def sgl_func(s, question_1, question_2):
            if system_prompt:
                s += sgl.system(system_prompt)
            for turn_question, turn_key in (
                (question_1, "answer_1"),
                (question_2, "answer_2"),
            ):
                s += sgl.user(turn_question)
                s += sgl.assistant(sgl.gen(turn_key, max_tokens=max_tokens))

    else:
        # Generic case: questions arrive as question_1..question_N kwargs.
        @sgl.function
        def sgl_func(s, **kwargs):
            if system_prompt:
                s += sgl.system(system_prompt)
            for turn in range(1, num_turns + 1):
                question_key = f"question_{turn}"
                if question_key in kwargs:
                    s += sgl.user(kwargs[question_key])
                    s += sgl.assistant(
                        sgl.gen(f"answer_{turn}", max_tokens=max_tokens)
                    )

    sgl_func.__name__ = function_name
    return sgl_func
239
+
240
+
241
def create_image_sgl_function(
    function_name: str = "get_image_answer",
    answer_key: str = "answer",
    max_tokens: int = 2048,
) -> Callable:
    """
    Create an SGL function for image-based Q&A.

    Args:
        function_name: Name assigned to the returned function.
        answer_key: Key under which the generated answer is stored.
        max_tokens: Maximum tokens to generate.

    Returns:
        SGL function decorated with @sgl.function
    """

    @sgl.function
    def sgl_func(s, image_path, question, **kwargs):
        """Multimodal flow: the user turn carries image + text, then the
        assistant generates the answer bound to ``answer_key``.

        Note: sgl.image() encodes the image into a model-supported
        multimodal input format.
        """
        s += sgl.user(sgl.image(image_path) + question)
        s += sgl.assistant(sgl.gen(answer_key, max_tokens=max_tokens))

    sgl_func.__name__ = function_name
    return sgl_func
progress/SpecForge/cache/compiled_kernels/26/c26l7dxpqbfol7d62sqakxdv4rgyh27yhm4hrctevbkw5t6kekia.py ADDED
@@ -0,0 +1,799 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = False
27
+ SM_SCALE : tl.constexpr = 0.08838834764831845
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ BLOCK_M1 : tl.constexpr = 64
36
+ BLOCK_N1 : tl.constexpr = 128
37
+ BLOCK_M2 : tl.constexpr = 128
38
+ BLOCK_N2 : tl.constexpr = 64
39
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
40
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
41
+ INDEX_DTYPE : tl.constexpr = tl.int32
42
+ Q = arg_Q
43
+ K = arg_K
44
+ V = arg_V
45
+ LSE = arg_LSE
46
+ DELTA = arg_DELTA
47
+ DO = arg_DO
48
+ DQ = arg_DQ
49
+ DV = arg_DV
50
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
51
+ KV_IDX = arg_KV_IDX
52
+ Q_NUM_BLKS = arg_Q_NUM_BLKS
53
+ Q_IDX = arg_Q_IDX
54
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
55
+ FULL_KV_IDX = arg_FULL_KV_IDX
56
+ FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
57
+ FULL_Q_IDX = arg_FULL_Q_IDX
58
+
59
+ # Sub notation for this kernel:
60
+ #
61
+ # Q: Query, K: Key, V: Value
62
+ # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
63
+ # DELTA: Precomputed sum(OUT*DO, axis=-1)
64
+ # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
65
+ # DK: Derivative of Key, is the written to via the store_output call due to some limitations with
66
+ # inductor codegen
67
+ # M: Number of queries, N: Number of keys/values
68
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
69
+ # V_HEAD_DIM: The dimension of the value embeddings
70
+ # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
71
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
72
+ # (Modifiable) Performance tuning options
73
+ # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
74
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
75
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
76
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
77
+ #
78
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
79
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
80
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
81
+ # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
82
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
83
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
84
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
85
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
86
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
87
+
88
+ # The below are kernel options that can be applied for certain score_mods,
89
+ # or involve a numerics vs. perf tradeoff
90
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
91
+ # about 20% more numerical error, but slightly faster.
92
+
93
+ # Define strides of inputs
94
+ stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1
95
+ stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1
96
+ stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1
97
+ stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1
98
+
99
+ stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1
100
+ stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1
101
+
102
+ ZQ = 1
103
+ HQ = 32
104
+ HKV = 8
105
+ Q_LEN = ks0
106
+ ZKV = 1
107
+ KV_LEN = ks1
108
+
109
+ MATMUL_PRECISION = Q.dtype.element_ty
110
+
111
+ pid = tl.program_id(0).to(INDEX_DTYPE)
112
+ NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
113
+ NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
114
+
115
+ off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
116
+ off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
117
+ off_zkv = off_zq % ZKV # kv batch idx
118
+
119
+ SPARSE_Z = 1
120
+ SPARSE_HQ = 1
121
+
122
+ sparse_idx_z = off_zq % SPARSE_Z
123
+
124
+ k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
125
+ v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
126
+ # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
127
+ # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
128
+ dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
129
+
130
+ # offset K, V, DV pointers for batch/kv-head
131
+ K += k_adj
132
+ V += v_adj
133
+ DV += dv_adj
134
+
135
+ RCP_LN2 = 1.44269504
136
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
137
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
138
+
139
+ if pid >= NUM_KV_BLOCKS:
140
+ off_pid = pid - NUM_KV_BLOCKS
141
+ # THIS BLOCK DOES DQ
142
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
143
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
144
+ off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
145
+ start_m2_block = off_pid % NUM_Q_BLOCKS
146
+ off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
147
+ stride_kv_num_blks_h = ks2
148
+ stride_kv_idx_h = ks3*ks4
149
+ stride_kv_idx_m = ks4
150
+
151
+ sparse_idx_hq2 = off_hq2 % SPARSE_HQ
152
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
153
+
154
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
155
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
156
+
157
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
158
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
159
+ do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
160
+ dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
161
+ off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
162
+
163
+ Q2 = Q + q_adj2
164
+ DO2 = DO + do_adj2
165
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
166
+ # if Q is broadcasted)
167
+ DQ2 = DQ + dq_adj2
168
+ LSE2 = LSE + off_chz2
169
+ DELTA2 = DELTA + off_chz2
170
+
171
+ # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
172
+ dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
173
+
174
+ start_m2 = start_m2_block * BLOCK_M2
175
+ offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
176
+
177
+ # load Q and do: they stay in SRAM throughout the inner loop.
178
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
179
+ do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
180
+
181
+ if PRESCALE_QK:
182
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
183
+
184
+ if IS_DIVISIBLE:
185
+ Di = tl.load(DELTA2 + offs_m2)
186
+ lse = tl.load(LSE2 + offs_m2)
187
+ else:
188
+ Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
189
+ lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
190
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
191
+ lse = lse[:, None]
192
+
193
+ # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
194
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
195
+ kv_indices = KV_IDX + sparse_kv_idx_offset
196
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
197
+ sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
198
+
199
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
200
+ dq = bwd_dq_inner(
201
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
202
+ K, V,
203
+ dq, q, do, Di, lse,
204
+ off_zq, off_hq2, offs_m2, offs_n2,
205
+ stride_kn, stride_kd, stride_vn, stride_vd,
206
+ kv_indices, sparse_kv_num_blocks,
207
+ MATMUL_PRECISION,
208
+ IS_FULL_BLOCKS=False,
209
+ )
210
+
211
+ if HAS_FULL_BLOCKS:
212
+ # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
213
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
214
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
215
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
216
+ sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
217
+
218
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
219
+ dq = bwd_dq_inner(
220
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
221
+ K, V,
222
+ dq, q, do, Di, lse,
223
+ off_zq, off_hq2, offs_m2, offs_n2,
224
+ stride_kn, stride_kd, stride_vn, stride_vd,
225
+ kv_indices, sparse_kv_num_blocks,
226
+ MATMUL_PRECISION,
227
+ IS_FULL_BLOCKS=True,
228
+ )
229
+
230
+ # Write back dQ.
231
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
232
+ dq *= SM_SCALE
233
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
234
+ tl.store(dq_ptrs, dq)
235
+ else:
236
+ tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
237
+ else:
238
+ # THIS BLOCK DOES DK & DV
239
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
240
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
241
+
242
+ pid_mask = pid // SPARSE_KV_MULTIPLE
243
+
244
+ stride_q_num_blks_h = 6
245
+ stride_q_idx_h = 36
246
+ stride_q_idx_n = 6
247
+
248
+
249
+ dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
250
+ dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
251
+
252
+ start_n1 = pid * BLOCK_N1
253
+ offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
254
+
255
+ # load K and V: they stay in SRAM throughout the inner loop.
256
+ k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
257
+ v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
258
+
259
+ if PRESCALE_QK:
260
+ k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
261
+
262
+ for off_g in range(0, GQA_SHARED_HEADS):
263
+ off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
264
+
265
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
266
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
267
+ do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
268
+ dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
269
+ off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
270
+
271
+ Q1 = Q + q_adj1
272
+ DO1 = DO + do_adj1
273
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
274
+ # if Q is broadcasted)
275
+ LSE1 = LSE + off_chz1
276
+ DELTA1 = DELTA + off_chz1
277
+
278
+ sparse_idx_hq1 = off_hq1 % SPARSE_HQ
279
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
280
+
281
+ sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
282
+ sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
283
+
284
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
285
+ # Q_IDX and Q_NUM_BLKS are always contiguous.
286
+ q_indices = Q_IDX + sparse_q_idx_offset
287
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
288
+ sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
289
+
290
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
291
+ dk, dv = bwd_dkdv_inner(
292
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
293
+ Q1, DO1, DELTA1, LSE1,
294
+ dk, dv, k, v,
295
+ off_zq, off_hq1, offs_n1, offs_m1,
296
+ stride_qm, stride_qd, stride_dom, stride_dod,
297
+ q_indices, sparse_q_num_blocks,
298
+ MATMUL_PRECISION,
299
+ IS_FULL_BLOCKS=False,
300
+ )
301
+
302
+
303
+ if HAS_FULL_BLOCKS:
304
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
305
+ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
306
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset
307
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
308
+ sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
309
+
310
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
311
+ dk, dv = bwd_dkdv_inner(
312
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
313
+ Q1, DO1, DELTA1, LSE1,
314
+ dk, dv, k, v,
315
+ off_zq, off_hq1, offs_n1, offs_m1,
316
+ stride_qm, stride_qd, stride_dom, stride_dod,
317
+ q_indices, sparse_q_num_blocks,
318
+ MATMUL_PRECISION,
319
+ IS_FULL_BLOCKS=True,
320
+ )
321
+
322
+ # Write back dV and dK.
323
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
324
+
325
+ index_n = offs_n1[:, None]
326
+ index_k = offs_k[None, :]
327
+ index_v = offs_v[None, :]
328
+
329
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
330
+ tl.store(dv_ptrs, dv)
331
+ else:
332
+ tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
333
+
334
+ dk *= SM_SCALE
335
+
336
+ if SAFE_HEAD_DIM:
337
+ mask = index_n < KV_LEN
338
+ else:
339
+ mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
340
+
341
+ # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
342
+ # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
343
+ tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
344
+ xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1
345
+ tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask)
346
+
347
+ @triton.jit
348
+ def bwd_dq_inner(
349
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
350
+ K, V, # pointers
351
+ dq, q, do, Di, lse,
352
+ off_z, off_hq, offs_m2, offs_n2,
353
+ stride_kn, stride_kd, stride_vn, stride_vd,
354
+ kv_indices, sparse_kv_num_blocks,
355
+ MATMUL_PRECISION,
356
+ IS_FULL_BLOCKS,
357
+ ):
358
+ PRESCALE_QK : tl.constexpr = False
359
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
360
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
361
+ WRITE_DQ : tl.constexpr = True
362
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
363
+ OUTPUT_MAX : tl.constexpr = False
364
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
365
+ IS_DIVISIBLE : tl.constexpr = False
366
+ SM_SCALE : tl.constexpr = 0.08838834764831845
367
+ GQA_SHARED_HEADS : tl.constexpr = 4
368
+ HAS_FULL_BLOCKS : tl.constexpr = True
369
+ QK_HEAD_DIM : tl.constexpr = 128
370
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
371
+ V_HEAD_DIM : tl.constexpr = 128
372
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
373
+ SAFE_HEAD_DIM : tl.constexpr = True
374
+ BLOCK_M1 : tl.constexpr = 64
375
+ BLOCK_N1 : tl.constexpr = 128
376
+ BLOCK_M2 : tl.constexpr = 128
377
+ BLOCK_N2 : tl.constexpr = 64
378
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
379
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
380
+ INDEX_DTYPE : tl.constexpr = tl.int32
381
+
382
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
383
+ RCP_LN2: tl.constexpr = 1.44269504
384
+ Q_LEN = ks0
385
+ KV_LEN = ks1
386
+
387
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
388
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
389
+
390
+ kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
391
+ vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
392
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
393
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
394
+
395
+ hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
396
+
397
+ for start_n in range(0, hi):
398
+ dq = bwd_dq_block_mn(
399
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
400
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
401
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
402
+ stride_kn, stride_kd, stride_vn, stride_vd,
403
+ kv_indices, sparse_kv_num_blocks,
404
+ MATMUL_PRECISION, RCP_LN2,
405
+ IS_FULL_BLOCKS,
406
+ )
407
+
408
+ # Increment pointers.
409
+ offset = get_offset_for_next_block(
410
+ start_n, kv_indices, sparse_kv_num_blocks,
411
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
412
+ )
413
+
414
+ kT_ptrs += offset * stride_kn
415
+ vT_ptrs += offset * stride_vn
416
+
417
+ offs_n2 += offset
418
+
419
+ return dq
420
+
421
+
422
+ @triton.jit
423
+ def bwd_dq_block_mn(
424
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
425
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
426
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
427
+ stride_kn, stride_kd, stride_vn, stride_vd,
428
+ kv_indices, sparse_kv_num_blocks,
429
+ MATMUL_PRECISION, RCP_LN2,
430
+ IS_FULL_BLOCKS,
431
+ ):
432
+ PRESCALE_QK : tl.constexpr = False
433
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
434
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
435
+ WRITE_DQ : tl.constexpr = True
436
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
437
+ OUTPUT_MAX : tl.constexpr = False
438
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
439
+ IS_DIVISIBLE : tl.constexpr = False
440
+ SM_SCALE : tl.constexpr = 0.08838834764831845
441
+ GQA_SHARED_HEADS : tl.constexpr = 4
442
+ HAS_FULL_BLOCKS : tl.constexpr = True
443
+ QK_HEAD_DIM : tl.constexpr = 128
444
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
445
+ V_HEAD_DIM : tl.constexpr = 128
446
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
447
+ SAFE_HEAD_DIM : tl.constexpr = True
448
+ BLOCK_M1 : tl.constexpr = 64
449
+ BLOCK_N1 : tl.constexpr = 128
450
+ BLOCK_M2 : tl.constexpr = 128
451
+ BLOCK_N2 : tl.constexpr = 64
452
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
453
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
454
+ INDEX_DTYPE : tl.constexpr = tl.int32
455
+
456
+
457
+ # NB reversed order to since K is transposed
458
+ kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
459
+ qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
460
+ if not PRESCALE_QK:
461
+ qk *= SM_SCALE
462
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
463
+ pre_mod_scores = qk
464
+ n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
465
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim
466
+ # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary
467
+ m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
468
+
469
+ tmp0 = (qk)
470
+ post_mod_scores = tmp0
471
+
472
+
473
+
474
+
475
+ if not IS_DIVISIBLE:
476
+ post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
477
+
478
+ if not IS_FULL_BLOCKS:
479
+ tmp1 = (m)
480
+ tmp2 = tl.full([1], 0, tl.int32)
481
+ tmp3 = tmp1 < tmp2
482
+ tmp4 = (n)
483
+ tmp5 = tmp4 <= tmp1
484
+ tmp6 = tmp3 & tmp5
485
+ tmp7 = tmp1 >= tmp2
486
+ tmp8 = tmp4 < tmp2
487
+ tmp9 = tmp7 & tmp8
488
+ tmp10 = tmp8 == 0
489
+ tmp11 = tmp7 & tmp10
490
+ tmp12 = tmp1 - tmp2
491
+ tmp13 = tl.full([1], 16, tl.int32)
492
+ tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
493
+ tmp15 = tmp4 - tmp2
494
+ tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
495
+ tmp17 = tmp14 == tmp16
496
+ tmp18 = tmp11 & tmp17
497
+ tmp19 = tmp9 | tmp18
498
+ tmp20 = tmp6 | tmp19
499
+ mask_mod_output = tmp20
500
+
501
+
502
+ # apply mask for partial masked block
503
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
504
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
505
+ if not PRESCALE_QK:
506
+ post_mod_scores *= RCP_LN2
507
+ p = tl.math.exp2(post_mod_scores - lse)
508
+ # Compute dP and dS.
509
+ # NB reversed order to since V is transposed
510
+ vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
511
+
512
+ dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
513
+ ds = p * (dp - Di[:, None])
514
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
515
+ tmp21 = (ds)
516
+ grad_scores = tmp21
517
+
518
+
519
+ if not IS_DIVISIBLE:
520
+ grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
521
+
522
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
523
+ if WRITE_DQ:
524
+ scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
525
+
526
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
527
+ ds = grad_scores
528
+
529
+ if not IS_FULL_BLOCKS:
530
+ # (grads) apply mask for partially unmasked block
531
+ ds = tl.where(mask_mod_output, ds, 0.0)
532
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
533
+ ds = ds.to(MATMUL_PRECISION)
534
+ # Compute dQ.
535
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
536
+
537
+ return dq
538
+
539
+
540
+ @triton.jit
541
+ def bwd_dkdv_inner(
542
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
543
+ Q, DO, DELTA, LSE, # pointers
544
+ dk, dv, k, v,
545
+ off_z, off_hq, offs_n1, offs_m1,
546
+ stride_qm, stride_qd, stride_dom, stride_dod,
547
+ q_indices, sparse_q_num_blocks,
548
+ MATMUL_PRECISION,
549
+ IS_FULL_BLOCKS,
550
+ ):
551
+ PRESCALE_QK : tl.constexpr = False
552
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
553
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
554
+ WRITE_DQ : tl.constexpr = True
555
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
556
+ OUTPUT_MAX : tl.constexpr = False
557
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
558
+ IS_DIVISIBLE : tl.constexpr = False
559
+ SM_SCALE : tl.constexpr = 0.08838834764831845
560
+ GQA_SHARED_HEADS : tl.constexpr = 4
561
+ HAS_FULL_BLOCKS : tl.constexpr = True
562
+ QK_HEAD_DIM : tl.constexpr = 128
563
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
564
+ V_HEAD_DIM : tl.constexpr = 128
565
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
566
+ SAFE_HEAD_DIM : tl.constexpr = True
567
+ BLOCK_M1 : tl.constexpr = 64
568
+ BLOCK_N1 : tl.constexpr = 128
569
+ BLOCK_M2 : tl.constexpr = 128
570
+ BLOCK_N2 : tl.constexpr = 64
571
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
572
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
573
+ INDEX_DTYPE : tl.constexpr = tl.int32
574
+
575
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
576
+ RCP_LN2: tl.constexpr = 1.44269504
577
+ Q_LEN = ks0
578
+ KV_LEN = ks1
579
+
580
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
581
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
582
+
583
+ qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
584
+ do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
585
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
586
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
587
+
588
+ # The minimum is needed to handle the case where we run with a super large
589
+ # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
590
+ hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
591
+
592
+ for start_m in range(0, hi):
593
+ dk, dv = bwd_dkdv_block_mn(
594
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
595
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
596
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
597
+ stride_qm, stride_qd, stride_dom, stride_dod,
598
+ q_indices, sparse_q_num_blocks,
599
+ MATMUL_PRECISION, RCP_LN2,
600
+ IS_FULL_BLOCKS,
601
+ )
602
+ # Increment pointers.
603
+ offset = get_offset_for_next_block(
604
+ start_m, q_indices, sparse_q_num_blocks,
605
+ SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
606
+ )
607
+
608
+ qT_ptrs += offset * stride_qm
609
+ do_ptrs += offset * stride_dom
610
+ offs_m1 += offset
611
+
612
+ return dk, dv
613
+
614
+
615
+ @triton.jit
616
+ def bwd_dkdv_block_mn(
617
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
618
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
619
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
620
+ stride_qm, stride_qd, stride_dom, stride_dod,
621
+ q_indices, sparse_q_num_blocks,
622
+ MATMUL_PRECISION, RCP_LN2,
623
+ IS_FULL_BLOCKS,
624
+ ):
625
+ PRESCALE_QK : tl.constexpr = False
626
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
627
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
628
+ WRITE_DQ : tl.constexpr = True
629
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
630
+ OUTPUT_MAX : tl.constexpr = False
631
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
632
+ IS_DIVISIBLE : tl.constexpr = False
633
+ SM_SCALE : tl.constexpr = 0.08838834764831845
634
+ GQA_SHARED_HEADS : tl.constexpr = 4
635
+ HAS_FULL_BLOCKS : tl.constexpr = True
636
+ QK_HEAD_DIM : tl.constexpr = 128
637
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
638
+ V_HEAD_DIM : tl.constexpr = 128
639
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
640
+ SAFE_HEAD_DIM : tl.constexpr = True
641
+ BLOCK_M1 : tl.constexpr = 64
642
+ BLOCK_N1 : tl.constexpr = 128
643
+ BLOCK_M2 : tl.constexpr = 128
644
+ BLOCK_N2 : tl.constexpr = 64
645
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
646
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
647
+ INDEX_DTYPE : tl.constexpr = tl.int32
648
+
649
+
650
+ # NB reversed order since Q is transposed
651
+ qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
652
+ # Load LSE before computing qk to reduce pipeline stall.
653
+ if IS_DIVISIBLE:
654
+ lse = tl.load(LSE + offs_m1)
655
+ else:
656
+ lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
657
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
658
+ qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
659
+ if not PRESCALE_QK:
660
+ qkT *= SM_SCALE
661
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
662
+ m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
663
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim
664
+ # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary
665
+ n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
666
+
667
+ pre_mod_scores = qkT
668
+ tmp22 = (qkT)
669
+ post_mod_scores = tmp22
670
+
671
+
672
+
673
+ if not IS_DIVISIBLE:
674
+ post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
675
+
676
+ if not IS_FULL_BLOCKS:
677
+ tmp23 = (m)
678
+ tmp24 = tl.full([1], 0, tl.int32)
679
+ tmp25 = tmp23 < tmp24
680
+ tmp26 = (n)
681
+ tmp27 = tmp26 <= tmp23
682
+ tmp28 = tmp25 & tmp27
683
+ tmp29 = tmp23 >= tmp24
684
+ tmp30 = tmp26 < tmp24
685
+ tmp31 = tmp29 & tmp30
686
+ tmp32 = tmp30 == 0
687
+ tmp33 = tmp29 & tmp32
688
+ tmp34 = tmp23 - tmp24
689
+ tmp35 = tl.full([1], 16, tl.int32)
690
+ tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35)
691
+ tmp37 = tmp26 - tmp24
692
+ tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35)
693
+ tmp39 = tmp36 == tmp38
694
+ tmp40 = tmp33 & tmp39
695
+ tmp41 = tmp31 | tmp40
696
+ tmp42 = tmp28 | tmp41
697
+ mask_mod_output = tmp42
698
+
699
+ # (grads) apply mask for fully masked block
700
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
701
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
702
+ if not PRESCALE_QK:
703
+ post_mod_scores *= RCP_LN2
704
+ pT = tl.math.exp2(post_mod_scores - lse[None, :])
705
+ do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
706
+ # Compute dV.
707
+ ppT = pT
708
+ dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
709
+ if IS_DIVISIBLE:
710
+ Di = tl.load(DELTA + offs_m1)
711
+ else:
712
+ Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
713
+ # Compute dP and dS.
714
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
715
+ dsT = pT * (dpT - Di[None, :])
716
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
717
+ tmp43 = (dsT)
718
+ grad_scores = tmp43
719
+
720
+
721
+
722
+ if not IS_DIVISIBLE:
723
+ grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
724
+
725
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
726
+ if not WRITE_DQ:
727
+ idx_b = off_z
728
+ idx_h = off_hq
729
+ idx_m = m
730
+ idx_n = n
731
+ scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
732
+
733
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
734
+ dsT = grad_scores
735
+ if not IS_FULL_BLOCKS:
736
+ # (grads) apply mask for partially unmasked block
737
+ dsT = tl.where(mask_mod_output, dsT, 0.0)
738
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
739
+ dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
740
+
741
+ return dk, dv
742
+
743
+ # Utility triton funcs
744
+ @triton.jit
745
+ def get_offset_for_next_block(
746
+ loop_iter, col_indices, total_blocks,
747
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
748
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
749
+ ):
750
+ if BLOCKS_ARE_CONTIGUOUS:
751
+ return BLOCK
752
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
753
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
754
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
755
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
756
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
757
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
758
+ return offset
759
+
760
+ @triton.jit
761
+ def get_bounded_indices(indices, max_len=None):
762
+ return indices % max_len if max_len is not None else indices
763
+
764
+ @triton.jit
765
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
766
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
767
+ return tl.load(block_ptr)
768
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
769
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
770
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
771
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
772
+ else:
773
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
774
+
775
+ @triton.jit
776
+ def load_checked_2d(
777
+ ptr,
778
+ offs_m,
779
+ offs_n,
780
+ stride_m,
781
+ stride_n,
782
+ IS_DIVISIBLE_M: tl.constexpr,
783
+ IS_DIVISIBLE_N: tl.constexpr,
784
+ M_LEN: tl.constexpr,
785
+ N_LEN: tl.constexpr,
786
+ ):
787
+ # Calculate final pointer if strides are provided
788
+ if stride_m is not None and stride_n is not None:
789
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
790
+
791
+ # Handle all masking cases
792
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
793
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
794
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
795
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
796
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
797
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
798
+ else: # Both divisible
799
+ return tl.load(ptr)
progress/SpecForge/cache/compiled_kernels/2d/c2d4e47kqxxnp6455gvkteqq3r336462zkbitosyeko6znxktn2b.py ADDED
@@ -0,0 +1,879 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['3_inference']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7g/c7gxkvfztxetv7w7i4s7mr7dlsdda3dfgq3f3uijvhozq6ggk4o4.py
38
+ # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
39
+ # Source node to ATen node mapping:
40
+ # flex_attention => flex_attention
41
+ # Graph fragment:
42
+ # %arg1_1 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:3" = PlaceHolder[target=arg1_1]
43
+ # %arg3_1 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:3" = PlaceHolder[target=arg3_1]
44
+ # %arg5_1 : Tensor "bf16[1, 8, s43, 128][1024*s43, 128, 1024, 1]cuda:3" = PlaceHolder[target=arg5_1]
45
+ # %buf0 : Tensor "f32[1, 32, 32, s37][1024*s37, 32*s37, s37, 1]cuda:3" = PlaceHolder[target=buf0]
46
+ # %buf1 : Tensor "f32[1, 32, 32, s37][1024*s37, 32*s37, s37, 1]cuda:3" = PlaceHolder[target=buf1]
47
+ # %arg9_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:3" = PlaceHolder[target=arg9_1]
48
+ # %arg6_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:3" = PlaceHolder[target=arg6_1]
49
+ # %arg10_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:3" = PlaceHolder[target=arg10_1]
50
+ # %arg11_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:3" = PlaceHolder[target=arg11_1]
51
+ # %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
52
+ # return %buf2
53
+ triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', '''
54
+ import triton
55
+ import triton.language as tl
56
+
57
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
58
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
59
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
60
+
61
+ @triton_heuristics.template(
62
+
63
+ num_stages=3,
64
+ num_warps=2,
65
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]},
66
+ inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}},
67
+
68
+ )
69
+ @triton.jit
70
+ def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2):
71
+ PRESCALE_QK : tl.constexpr = False
72
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
73
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
74
+ WRITE_DQ : tl.constexpr = True
75
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
76
+ OUTPUT_MAX : tl.constexpr = False
77
+ FLOAT32_PRECISION : tl.constexpr = 'ieee'
78
+ IS_DIVISIBLE : tl.constexpr = False
79
+ GQA_SHARED_HEADS : tl.constexpr = 4
80
+ HAS_FULL_BLOCKS : tl.constexpr = True
81
+ SM_SCALE : tl.constexpr = 0.08838834764831845
82
+ SPLIT_KV : tl.constexpr = 32
83
+ QK_HEAD_DIM : tl.constexpr = 128
84
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
85
+ V_HEAD_DIM : tl.constexpr = 128
86
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
87
+ SAFE_HEAD_DIM : tl.constexpr = True
88
+ BLOCK_M : tl.constexpr = 512
89
+ SAFE_M_BOUNDARY : tl.constexpr = False
90
+ SAFE_N_BOUNDARY : tl.constexpr = True
91
+ BLOCK_N : tl.constexpr = 64
92
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
93
+ USE_TMA : tl.constexpr = False
94
+ INDEX_DTYPE : tl.constexpr = tl.int32
95
+ Q = arg_Q
96
+ K = arg_K
97
+ V = arg_V
98
+ M = arg_M
99
+ L = arg_L
100
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
101
+ KV_IDX = arg_KV_IDX
102
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
103
+ FULL_KV_IDX = arg_FULL_KV_IDX
104
+
105
+ # Sub notation for this kernel:
106
+ # Q: Query, K: Key, V: Value
107
+ # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split
108
+ # M: Number of queries, N: Number of keys/values
109
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
110
+ # V_HEAD_DIM: The dimension of the value embeddings
111
+ # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block
112
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits
113
+ # (Modifiable) Config options:
114
+ # SPLIT_KV: number of blocks K & V are split into
115
+ # TILE_KV: length of each local KV split
116
+ # BLOCK_M: block size that Q is padded along seqlen dim.
117
+ # BLOCK_N: block size of K & V along N dimension.
118
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
119
+ #
120
+ # change of base out of the loop
121
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
122
+ # is not masked out? If so, we can skip an extra safety check
123
+ # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query.
124
+ # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value.
125
+
126
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base.
127
+ #
128
+ # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim.
129
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
130
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
131
+ #
132
+ #
133
+ # Output: ACC output accumulated across local KV split.
134
+
135
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
136
+
137
+ # Define Q Strides
138
+ stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1
139
+ stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1
140
+ stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1
141
+ stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1
142
+ stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1
143
+
144
+
145
+ Z = 1
146
+ ZKV = 1
147
+ HKV = 8
148
+ G: tl.constexpr = GQA_SHARED_HEADS
149
+ HQ = HKV * G
150
+ Q_LEN = ks0
151
+ KV_LEN = ks1
152
+
153
+ MATMUL_PRECISION = Q.dtype.element_ty
154
+
155
+ # Make sure each split is a multiple of BLOCK_N
156
+ TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV)
157
+ TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N
158
+ TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N)
159
+
160
+ off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV
161
+ off_zkv = off_z % ZKV
162
+ off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV
163
+ off_t = tl.program_id(1).to(INDEX_DTYPE)
164
+
165
+ q_offset = off_z * stride_qz + off_hkv * stride_qh
166
+ k_offset = off_zkv * stride_kz + off_hkv * stride_kh
167
+ v_offset = off_zkv * stride_vz + off_hkv * stride_vh
168
+
169
+ K = K + k_offset
170
+ V = V + v_offset
171
+
172
+ SPARSE_Z = 1
173
+ SPARSE_HQ = 1
174
+
175
+ sparse_idx_z = off_z % SPARSE_Z
176
+ sparse_idx_h = off_hkv % SPARSE_HQ
177
+
178
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
179
+ SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE)
180
+
181
+ # initialize pointer to m and l
182
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
183
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
184
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
185
+
186
+ # initialize offsets
187
+ tl.device_assert(BLOCK_M % G == 0)
188
+ BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G
189
+ off_g = tl.arange(0, G) # [G]
190
+ offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M]
191
+ offs_hq = offs_g + off_hkv * G
192
+ off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ]
193
+ offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M]
194
+ offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED)
195
+ offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED)
196
+
197
+ # Get HZ offsets for KV_NUM_BLKS and KV_IDX
198
+ stride_block_z, stride_block_h, stride_block_row = 1, 1, 1
199
+ sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h
200
+ stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1
201
+ sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h
202
+
203
+ # Calculate KV blocks that belong this CTA.
204
+ block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block
205
+ block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N
206
+
207
+ q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :]
208
+
209
+ if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM:
210
+ q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN))
211
+ elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM:
212
+ q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM)
213
+ elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM:
214
+ q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN)
215
+ else:
216
+ q = tl.load(Q + q_offset + q_range)
217
+
218
+ q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED])
219
+
220
+
221
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
222
+ # find first kv block we are loading and the number of blocks we are loading
223
+ # Offset the kv_indices tensor by the correct batch and head
224
+ kv_indices = KV_IDX + sparse_idx_hz_offset
225
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset)
226
+ MAX_KV_IDX = 1
227
+ indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX)
228
+ off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
229
+ off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
230
+ # first kv block we're loading
231
+
232
+ # last valid block according to sparse mask
233
+ block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
234
+
235
+ offs_n = tl.arange(0, BLOCK_N) + off_n
236
+
237
+ desc_k = None
238
+ desc_v = None
239
+
240
+ acc, l_i, m_i = forward_inner(
241
+ arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
242
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
243
+ # accumulatd values
244
+ acc, l_i, m_i,
245
+ #offsets
246
+ off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
247
+ off_n,
248
+ #block sparse data
249
+ kv_indices, kv_num_blocks,
250
+ block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
251
+ MATMUL_PRECISION,
252
+ stride_kk, stride_kn, stride_vn, stride_vk,
253
+ IS_FULL_BLOCKS=False,
254
+ )
255
+
256
+
257
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
258
+ # We know these blocks are guaranteed to be "full", so we don't need to
259
+ # apply mask_mod to them - only score_mod
260
+ if HAS_FULL_BLOCKS:
261
+ kv_indices = FULL_KV_IDX + sparse_idx_hz_offset
262
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset)
263
+ # Assign full block in a reverse order for off_t. Prioritize the last CTA.
264
+ block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE
265
+ block_n_end = block_n_start + TILE_KV_MULTIPLE
266
+ indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX)
267
+ off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
268
+ off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
269
+
270
+ # last valid block according to sparse mask
271
+ block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
272
+
273
+ offs_n = tl.arange(0, BLOCK_N) + off_n
274
+
275
+ acc, l_i, m_i = forward_inner(
276
+ arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
277
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
278
+ # accumulatd values
279
+ acc, l_i, m_i,
280
+ #offsets
281
+ off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
282
+ off_n,
283
+ #block sparse data
284
+ kv_indices, kv_num_blocks,
285
+ block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
286
+ MATMUL_PRECISION,
287
+ stride_kk, stride_kn, stride_vn, stride_vk,
288
+ IS_FULL_BLOCKS=True,
289
+ )
290
+
291
+ m_offset = off_t * stride_mt + off_z * stride_mz
292
+ l_offset = off_t * stride_lt + off_z * stride_lz
293
+
294
+ M_block_ptr = tl.make_block_ptr(
295
+ base=M + m_offset,
296
+ shape=(G, Q_LEN), # (G, M)
297
+ strides=(stride_mh, stride_mm),
298
+ offsets=(off_hkv*G, 0),
299
+ block_shape=(G, BLOCK_M_PER_HQ),
300
+ order=(1, 0)
301
+ )
302
+ L_block_ptr = tl.make_block_ptr(
303
+ base=L + l_offset,
304
+ shape=(G, Q_LEN), # (G, M)
305
+ strides=(stride_lh, stride_lm),
306
+ offsets=(off_hkv*G, 0),
307
+ block_shape=(G, BLOCK_M_PER_HQ),
308
+ order=(1, 0)
309
+ )
310
+
311
+ # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16)
312
+ m_i = m_i.reshape(G, BLOCK_M_PER_HQ)
313
+ l_i = l_i.reshape(G, BLOCK_M_PER_HQ)
314
+ if SAFE_M_BOUNDARY:
315
+ tl.store(M_block_ptr, m_i)
316
+ tl.store(L_block_ptr, l_i)
317
+ else:
318
+ tl.store(M_block_ptr, m_i, boundary_check=(1,))
319
+ tl.store(L_block_ptr, l_i, boundary_check=(1,))
320
+
321
+ # -- store output
322
+ idx_z = off_z
323
+ idx_t = off_t
324
+ idx_hq = off_hkv*G + off_g[:, None, None]
325
+ idx_m = off_m[None, :, None]
326
+ idx_d = offs_vd[None, None, :]
327
+
328
+ mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
329
+ acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
330
+ xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0
331
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask)
332
+
333
+
334
+ # Utility triton funcs
335
+ @triton.jit
336
+ def get_offset_for_next_block(
337
+ loop_iter, col_indices, total_blocks,
338
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
339
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
340
+ ):
341
+ if BLOCKS_ARE_CONTIGUOUS:
342
+ return BLOCK
343
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
344
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
345
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
346
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
347
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
348
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
349
+ return offset
350
+
351
+ @triton.jit
352
+ def get_bounded_indices(indices, max_len=None):
353
+ return indices % max_len if max_len is not None else indices
354
+
355
+ @triton.jit
356
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
357
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
358
+ return tl.load(block_ptr)
359
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
360
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
361
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
362
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
363
+ else:
364
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
365
+
366
+ @triton.jit
367
+ def load_checked_2d(
368
+ ptr,
369
+ offs_m,
370
+ offs_n,
371
+ stride_m,
372
+ stride_n,
373
+ IS_DIVISIBLE_M: tl.constexpr,
374
+ IS_DIVISIBLE_N: tl.constexpr,
375
+ M_LEN: tl.constexpr,
376
+ N_LEN: tl.constexpr,
377
+ ):
378
+ # Calculate final pointer if strides are provided
379
+ if stride_m is not None and stride_n is not None:
380
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
381
+
382
+ # Handle all masking cases
383
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
384
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
385
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
386
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
387
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
388
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
389
+ else: # Both divisible
390
+ return tl.load(ptr)
391
+
392
+
393
+ # Common Imports
394
+ @triton.jit
395
+ def forward_block_mn(
396
+ arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
397
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
398
+ # accumulated values
399
+ acc, l_i, m_i,
400
+ # Offsets
401
+ off_z, off_h, offs_m, offs_n,
402
+ # Offsets needed for TMA loads
403
+ kv_start,
404
+ kv_offset,
405
+ MATMUL_PRECISION, RCP_LN2,
406
+ # Strides for K and V
407
+ stride_kk, stride_kn, stride_vn, stride_vk,
408
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
409
+
410
+ ):
411
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
412
+ PRESCALE_QK : tl.constexpr = False
413
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
414
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
415
+ WRITE_DQ : tl.constexpr = True
416
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
417
+ OUTPUT_MAX : tl.constexpr = False
418
+ FLOAT32_PRECISION : tl.constexpr = 'ieee'
419
+ IS_DIVISIBLE : tl.constexpr = False
420
+ GQA_SHARED_HEADS : tl.constexpr = 4
421
+ HAS_FULL_BLOCKS : tl.constexpr = True
422
+ SM_SCALE : tl.constexpr = 0.08838834764831845
423
+ SPLIT_KV : tl.constexpr = 32
424
+ QK_HEAD_DIM : tl.constexpr = 128
425
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
426
+ V_HEAD_DIM : tl.constexpr = 128
427
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
428
+ SAFE_HEAD_DIM : tl.constexpr = True
429
+ BLOCK_M : tl.constexpr = 512
430
+ SAFE_M_BOUNDARY : tl.constexpr = False
431
+ SAFE_N_BOUNDARY : tl.constexpr = True
432
+ BLOCK_N : tl.constexpr = 64
433
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
434
+ USE_TMA : tl.constexpr = False
435
+ INDEX_DTYPE : tl.constexpr = tl.int32
436
+
437
+
438
+ # -- load k --
439
+ # NB reversed order to since K is transposed
440
+ kv_base_offset = kv_start + kv_offset
441
+
442
+ # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
443
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
444
+ offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
445
+ k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
446
+
447
+ k = tl.trans(k)
448
+ # -- compute qk ---
449
+ qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
450
+ if not PRESCALE_QK:
451
+ qk *= SM_SCALE
452
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
453
+ # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
454
+ # which is larger than the actual number of elements. To avoid access memory out of bound,
455
+ # we need to mask out the elements that are out of Q_LEN & KV_LEN.
456
+ m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
457
+ n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
458
+
459
+ tmp0 = (qk)
460
+ post_mod_scores = tmp0
461
+
462
+
463
+ if CHECK_BLOCK_BOUNDARY:
464
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
465
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
466
+
467
+ if not IS_FULL_BLOCKS:
468
+ tmp1 = (m)
469
+ tmp2 = tl.full([1], 0, tl.int32)
470
+ tmp3 = tmp1 < tmp2
471
+ tmp4 = (n)
472
+ tmp5 = tmp4 <= tmp1
473
+ tmp6 = tmp3 & tmp5
474
+ tmp7 = tmp1 >= tmp2
475
+ tmp8 = tmp4 < tmp2
476
+ tmp9 = tmp7 & tmp8
477
+ tmp10 = tmp8 == 0
478
+ tmp11 = tmp7 & tmp10
479
+ tmp12 = tmp1 - tmp2
480
+ tmp13 = tl.full([1], 16, tl.int32)
481
+ tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
482
+ tmp15 = tmp4 - tmp2
483
+ tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
484
+ tmp17 = tmp14 == tmp16
485
+ tmp18 = tmp11 & tmp17
486
+ tmp19 = tmp9 | tmp18
487
+ tmp20 = tmp6 | tmp19
488
+ mask_mod_output = tmp20
489
+
490
+
491
+ if CHECK_BLOCK_BOUNDARY:
492
+ mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
493
+ # apply mask for partially unmasked blocks
494
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
495
+
496
+ if not PRESCALE_QK:
497
+ post_mod_scores *= RCP_LN2
498
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
499
+
500
+ # -- compute scaling constant ---
501
+ m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
502
+ if not ROWS_GUARANTEED_SAFE:
503
+ masked_out_rows = (m_ij == float("-inf"))
504
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
505
+ else:
506
+ m_ij_masked = m_ij
507
+
508
+ alpha = tl.math.exp2(m_i - m_ij_masked)
509
+ p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
510
+
511
+ # NB: l_i update is pulled up here since it's a bit faster
512
+ # NB: For headdim=256, it's faster to move it back down to after m_i =
513
+ # m_ij
514
+ l_i = l_i * alpha + tl.sum(p, 1)
515
+ # # -- scale and update acc --
516
+ acc = acc * alpha[:, None]
517
+ # Calculate offsets for V loading - reuse kv_base_offset from K loading
518
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
519
+ v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
520
+ acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
521
+
522
+ # -- update m_i
523
+ m_i = m_ij
524
+
525
+ return acc, l_i, m_i
526
+
527
+ @triton.jit
528
+ def forward_inner(
529
+ arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
530
+ q, K, V,
531
+ desc_k, desc_v, Q_LEN, KV_LEN,
532
+ # accumulated values
533
+ acc, l_i, m_i,
534
+ # Offsets used as inputs to score_mod & mask_mod
535
+ # of size [BLOCK_M, BLOCK_N] or scalar.
536
+ off_z, off_h, offs_m, offs_n,
537
+ # Offsets needed for TMA loads
538
+ kv_start,
539
+ # blocksparse data
540
+ kv_indices, kv_num_blocks,
541
+ # start kv and end kv block
542
+ block_n_start, block_n_end,
543
+ MATMUL_PRECISION,
544
+ # Strides for K and V
545
+ stride_kk, stride_kn, stride_vn, stride_vk,
546
+ IS_FULL_BLOCKS,
547
+ ):
548
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
549
+ PRESCALE_QK : tl.constexpr = False
550
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
551
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
552
+ WRITE_DQ : tl.constexpr = True
553
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
554
+ OUTPUT_MAX : tl.constexpr = False
555
+ FLOAT32_PRECISION : tl.constexpr = 'ieee'
556
+ IS_DIVISIBLE : tl.constexpr = False
557
+ GQA_SHARED_HEADS : tl.constexpr = 4
558
+ HAS_FULL_BLOCKS : tl.constexpr = True
559
+ SM_SCALE : tl.constexpr = 0.08838834764831845
560
+ SPLIT_KV : tl.constexpr = 32
561
+ QK_HEAD_DIM : tl.constexpr = 128
562
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
563
+ V_HEAD_DIM : tl.constexpr = 128
564
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
565
+ SAFE_HEAD_DIM : tl.constexpr = True
566
+ BLOCK_M : tl.constexpr = 512
567
+ SAFE_M_BOUNDARY : tl.constexpr = False
568
+ SAFE_N_BOUNDARY : tl.constexpr = True
569
+ BLOCK_N : tl.constexpr = 64
570
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
571
+ USE_TMA : tl.constexpr = False
572
+ INDEX_DTYPE : tl.constexpr = tl.int32
573
+
574
+
575
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
576
+ RCP_LN2: tl.constexpr = 1.44269504
577
+
578
+ if PRESCALE_QK:
579
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
580
+
581
+ kv_offset = 0
582
+
583
+ # loop over k, v and update accumulator until block_n_end
584
+ for start_n in range(block_n_start, block_n_end):
585
+ # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
586
+ if IS_DIVISIBLE:
587
+ acc, l_i, m_i = forward_block_mn(
588
+ arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
589
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
590
+ # accumulated values
591
+ acc, l_i, m_i,
592
+ # Offsets
593
+ off_z, off_h, offs_m, offs_n,
594
+ # Offsets needed for TMA loads
595
+ kv_start,
596
+ kv_offset,
597
+ MATMUL_PRECISION, RCP_LN2,
598
+ # Strides for K and V
599
+ stride_kk, stride_kn, stride_vn, stride_vk,
600
+ IS_FULL_BLOCKS,
601
+ )
602
+ else:
603
+ # Benchmark shows even we applied mod & mask to each block for non divisible seqlen,
604
+ # it's on par or slightly faster than only applying to the last block in fwd.
605
+ # However, we choose different strategy for bwd, where we only apply mod & mask
606
+ # to the last block because it's faster a lot.
607
+ acc, l_i, m_i = forward_block_mn(
608
+ arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
609
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
610
+ # accumulated values
611
+ acc, l_i, m_i,
612
+ # Offsets
613
+ off_z, off_h, offs_m, offs_n,
614
+ # Offsets needed for TMA loads
615
+ kv_start,
616
+ kv_offset,
617
+ MATMUL_PRECISION, RCP_LN2,
618
+ # Strides for K and V
619
+ stride_kk, stride_kn, stride_vn, stride_vk,
620
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
621
+ )
622
+
623
+
624
+
625
+ offset = get_offset_for_next_block(
626
+ start_n, kv_indices, kv_num_blocks,
627
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
628
+ )
629
+
630
+ offs_n = offs_n + offset
631
+ kv_offset += offset
632
+
633
+
634
+ return acc, l_i, m_i
635
+ ''', device_str='cuda')
636
+
637
+
638
+ # kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/6g/c6gb52skvqs7or57vd3zu5um3r5rnmeimd5qam27l5j7uqx7t4ai.py
639
+ # Topologically Sorted Source Nodes: [flex_attention, lse_scaled], Original ATen: [aten.mul]
640
+ # Source node to ATen node mapping:
641
+ # flex_attention => flex_attention
642
+ # lse_scaled => mul_9
643
+ # Graph fragment:
644
+ # %buf3 : Tensor = PlaceHolder[target=buf3]
645
+ # %buf4 : Tensor = PlaceHolder[target=buf4]
646
+ # %buf5 : Tensor "f32[1, 1, 32, s37][32*s37, 32*s37, s37, 1]cuda:3" = PlaceHolder[target=buf5]
647
+ # %buf7 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:3" = PlaceHolder[target=buf7]
648
+ # %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
649
+ # %mul_9 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {})
650
+ # return %buf5,%buf7,%mul_9
651
+ triton_per_fused_mul_1 = async_compile.triton('triton_per_fused_mul_1', '''
652
+ import triton
653
+ import triton.language as tl
654
+
655
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
656
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
657
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
658
+ triton_helpers.set_driver_to_gpu()
659
+
660
+ @triton_heuristics.persistent_reduction(
661
+ size_hints={'x': 4096, 'r0_': 32},
662
+ reduction_hint=ReductionHint.DEFAULT,
663
+ filename=__file__,
664
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
665
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
666
+ )
667
+ @triton.jit
668
+ def triton_per_fused_mul_1(in_ptr0, in_ptr1, out_ptr0, out_ptr1, out_ptr2, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr):
669
+ r0_numel = 32
670
+ R0_BLOCK: tl.constexpr = 32
671
+ rnumel = r0_numel
672
+ RBLOCK: tl.constexpr = R0_BLOCK
673
+ xoffset = tl.program_id(0) * XBLOCK
674
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
675
+ xmask = xindex < xnumel
676
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
677
+ r0_offset = 0
678
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
679
+ roffset = r0_offset
680
+ rindex = r0_index
681
+ r0_1 = r0_index
682
+ x0 = xindex
683
+ x2 = (xindex % ks0)
684
+ x3 = triton_helpers.div_floor_integer(xindex, ks0)
685
+ tmp0 = tl.load(in_ptr0 + (x0 + 32*ks0*r0_1), xmask, other=0.0)
686
+ tmp5 = tl.load(in_ptr1 + (x0 + 32*ks0*r0_1), xmask, other=0.0)
687
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
688
+ tmp3 = tl.where(xmask, tmp1, float("-inf"))
689
+ tmp4 = triton_helpers.max2(tmp3, 1)[:, None].to(tl.float32)
690
+ tmp6 = float("-inf")
691
+ tmp7 = tmp4 == tmp6
692
+ tmp8 = tmp0 - tmp4
693
+ tmp9 = 0.0
694
+ tmp10 = tl.where(tmp7, tmp9, tmp8)
695
+ tmp11 = libdevice.exp2(tmp10)
696
+ tmp12 = tmp5 * tmp11
697
+ tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK])
698
+ tmp15 = tl.where(xmask, tmp13, 0)
699
+ tmp16 = tl.sum(tmp15, 1)[:, None].to(tl.float32)
700
+ tmp17 = 1.0
701
+ tmp18 = tl.where(tmp7, tmp17, tmp16)
702
+ tmp19 = libdevice.log2(tmp18)
703
+ tmp20 = tmp19 + tmp4
704
+ tmp21 = 0.6931471805599453
705
+ tmp22 = tmp20 * tmp21
706
+ tl.store(out_ptr2 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp22, xmask)
707
+ tl.store(out_ptr0 + (x0), tmp4, xmask)
708
+ tl.store(out_ptr1 + (x0), tmp16, xmask)
709
+ ''', device_str='cuda')
710
+
711
+
712
+ # kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/jt/cjtngjzio5oudkq4n4xggwz5enmgujrff3ktfnon7oykgb7as5tu.py
713
+ # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
714
+ # Source node to ATen node mapping:
715
+ # flex_attention => flex_attention, getitem
716
+ # Graph fragment:
717
+ # %buf2 : Tensor "f32[1, 32, 32, s37, 128][131072*s37, 4096*s37, 128*s37, 128, 1]cuda:3" = PlaceHolder[target=buf2]
718
+ # %buf5 : Tensor "f32[1, 1, 32, s37][32*s37, 32*s37, s37, 1]cuda:3" = PlaceHolder[target=buf5]
719
+ # %buf3 : Tensor = PlaceHolder[target=buf3]
720
+ # %buf8 : Tensor "f32[1, 32, s37, 128][4096*s37, 128*s37, 128, 1]cuda:3" = PlaceHolder[target=buf8]
721
+ # %buf7 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:3" = PlaceHolder[target=buf7]
722
+ # %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
723
+ # %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:3"[num_users=1] = call_function[target=operator.getitem](args = (%flex_attention, 0), kwargs = {})
724
+ # return %buf8,%getitem
725
+ triton_per_fused_2 = async_compile.triton('triton_per_fused_2', '''
726
+ import triton
727
+ import triton.language as tl
728
+
729
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
730
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
731
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
732
+ triton_helpers.set_driver_to_gpu()
733
+
734
+ @triton_heuristics.persistent_reduction(
735
+ size_hints={'x': 524288, 'r0_': 32},
736
+ reduction_hint=ReductionHint.DEFAULT,
737
+ filename=__file__,
738
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]},
739
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
740
+ )
741
+ @triton.jit
742
+ def triton_per_fused_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr):
743
+ r0_numel = 32
744
+ R0_BLOCK: tl.constexpr = 32
745
+ rnumel = r0_numel
746
+ RBLOCK: tl.constexpr = R0_BLOCK
747
+ xoffset = tl.program_id(0) * XBLOCK
748
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
749
+ xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
750
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
751
+ r0_offset = 0
752
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
753
+ roffset = r0_offset
754
+ rindex = r0_index
755
+ r0_2 = r0_index
756
+ x5 = xindex
757
+ x1 = xindex // 128
758
+ x0 = (xindex % 128)
759
+ x3 = ((xindex // 128) % ks0)
760
+ x4 = xindex // ks1
761
+ tmp0 = tl.load(in_ptr0 + (x5 + 4096*ks0*r0_2), None)
762
+ tmp1 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last')
763
+ tmp4 = tl.load(in_ptr2 + (x1 + 32*ks0*r0_2), None, eviction_policy='evict_last')
764
+ tmp13 = tl.load(in_ptr3 + (x1), None, eviction_policy='evict_last')
765
+ tmp2 = float("-inf")
766
+ tmp3 = tmp1 == tmp2
767
+ tmp5 = tmp4 - tmp1
768
+ tmp6 = 0.0
769
+ tmp7 = tl.where(tmp3, tmp6, tmp5)
770
+ tmp8 = libdevice.exp2(tmp7)
771
+ tmp9 = tmp0 * tmp8
772
+ tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK])
773
+ tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32)
774
+ tmp14 = 1.0
775
+ tmp15 = tl.where(tmp3, tmp14, tmp13)
776
+ tmp16 = (tmp12 / tmp15)
777
+ tmp17 = tmp16.to(tl.float32)
778
+ tl.store(out_ptr1 + (x0 + 128*x4 + 4096*x3), tmp17, None)
779
+ ''', device_str='cuda')
780
+
781
+
782
+ async_compile.wait(globals())
783
+ del async_compile
784
+
785
+ class Runner:
786
+ def __init__(self, partitions):
787
+ self.partitions = partitions
788
+
789
+ def recursively_apply_fns(self, fns):
790
+ new_callables = []
791
+ for fn, c in zip(fns, self.partitions):
792
+ new_callables.append(fn(c))
793
+ self.partitions = new_callables
794
+
795
+ def call(self, args):
796
+ arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1 = args
797
+ args.clear()
798
+ s50 = arg0_1
799
+ s0 = arg2_1
800
+ s43 = arg4_1
801
+ s37 = arg7_1
802
+ s71 = arg8_1
803
+ assert_size_stride(arg1_1, (1, 32, s37, 128), (4096*s37, 128, 4096, 1))
804
+ assert_size_stride(arg3_1, (1, 8, s0, 128), (1024*s0, 128, 1024, 1))
805
+ assert_size_stride(arg5_1, (1, 8, s43, 128), (1024*s43, 128, 1024, 1))
806
+ assert_size_stride(arg6_1, (1, 1, 1, 1), (1, 1, 1, 1))
807
+ assert_size_stride(arg9_1, (1, 1, 1), (1, 1, 1))
808
+ assert_size_stride(arg10_1, (1, 1, 1), (1, 1, 1))
809
+ assert_size_stride(arg11_1, (1, 1, 1, 1), (1, 1, 1, 1))
810
+ assert_size_stride(arg12_1, (1, 1, 1), (1, 1, 1))
811
+ assert_size_stride(arg13_1, (1, 1, 1, 1), (1, 1, 1, 1))
812
+ assert_size_stride(arg14_1, (1, 1, 1), (1, 1, 1))
813
+ assert_size_stride(arg15_1, (1, 1, 1, 1), (1, 1, 1, 1))
814
+ with torch.cuda._DeviceGuard(3):
815
+ torch.cuda.set_device(3)
816
+ buf0 = empty_strided_cuda((1, 32, 32, s37), (1024*s37, 32*s37, s37, 1), torch.float32)
817
+ buf1 = empty_strided_cuda((1, 32, 32, s37), (1024*s37, 32*s37, s37, 1), torch.float32)
818
+ buf2 = empty_strided_cuda((1, 32, 32, s37, 128), (131072*s37, 4096*s37, 128*s37, 128, 1), torch.float32)
819
+ # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
820
+ stream3 = get_raw_stream(3)
821
+ triton_tem_fused_0.run(arg1_1, arg3_1, arg5_1, buf0, buf1, arg9_1, arg6_1, arg10_1, arg11_1, buf2, s37, s0, s43, 8, 32, 1, stream=stream3)
822
+ del arg10_1
823
+ del arg11_1
824
+ del arg1_1
825
+ del arg3_1
826
+ del arg5_1
827
+ del arg6_1
828
+ del arg9_1
829
+ buf5 = empty_strided_cuda((1, 1, 32, s37), (32*s37, 32*s37, s37, 1), torch.float32)
830
+ buf7 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32)
831
+ buf10 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32)
832
+ # Topologically Sorted Source Nodes: [flex_attention, lse_scaled], Original ATen: [aten.mul]
833
+ triton_per_fused_mul_1_xnumel = 32*s37
834
+ stream3 = get_raw_stream(3)
835
+ triton_per_fused_mul_1.run(buf0, buf1, buf5, buf7, buf10, s37, triton_per_fused_mul_1_xnumel, 32, stream=stream3)
836
+ del buf1
837
+ ps0 = 128*s37
838
+ buf9 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16)
839
+ # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
840
+ triton_per_fused_2_xnumel = 4096*s37
841
+ stream3 = get_raw_stream(3)
842
+ triton_per_fused_2.run(buf2, buf5, buf0, buf7, buf9, s37, ps0, triton_per_fused_2_xnumel, 32, stream=stream3)
843
+ del buf0
844
+ del buf2
845
+ del buf5
846
+ del buf7
847
+ return (buf9, buf10, )
848
+
849
+ runner = Runner(partitions=[])
850
+ call = runner.call
851
+ recursively_apply_fns = runner.recursively_apply_fns
852
+
853
+
854
+ def benchmark_compiled_module(times=10, repeat=10):
855
+ from torch._dynamo.testing import rand_strided
856
+ from torch._inductor.utils import print_performance
857
+ arg0_1 = 96
858
+ arg1_1 = rand_strided((1, 32, 96, 128), (393216, 128, 4096, 1), device='cuda:3', dtype=torch.bfloat16)
859
+ arg2_1 = 96
860
+ arg3_1 = rand_strided((1, 8, 96, 128), (98304, 128, 1024, 1), device='cuda:3', dtype=torch.bfloat16)
861
+ arg4_1 = 96
862
+ arg5_1 = rand_strided((1, 8, 96, 128), (98304, 128, 1024, 1), device='cuda:3', dtype=torch.bfloat16)
863
+ arg6_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:3', dtype=torch.int32)
864
+ arg7_1 = 96
865
+ arg8_1 = 96
866
+ arg9_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:3', dtype=torch.int32)
867
+ arg10_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:3', dtype=torch.int32)
868
+ arg11_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:3', dtype=torch.int32)
869
+ arg12_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:3', dtype=torch.int32)
870
+ arg13_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:3', dtype=torch.int32)
871
+ arg14_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:3', dtype=torch.int32)
872
+ arg15_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:3', dtype=torch.int32)
873
+ fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1])
874
+ return print_performance(fn, times=times, repeat=repeat)
875
+
876
+
877
+ if __name__ == "__main__":
878
+ from torch._inductor.wrapper_benchmark import compiled_module_main
879
+ compiled_module_main('None', benchmark_compiled_module)
progress/SpecForge/cache/compiled_kernels/2g/c2gswut4q57fp2ueybipg5qfqiy5coitofujwdnvqdwhr7nbvnyq.py ADDED
@@ -0,0 +1,534 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = False
27
+ SM_SCALE : tl.constexpr = 0.08838834764831845
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ USE_TMA : tl.constexpr = False
36
+ BLOCK_M : tl.constexpr = 128
37
+ BLOCK_N : tl.constexpr = 64
38
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
39
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
40
+ INDEX_DTYPE : tl.constexpr = tl.int32
41
+ Q = arg_Q
42
+ K = arg_K
43
+ V = arg_V
44
+ LSE = arg_LSE
45
+ MAX = arg_MAX
46
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
47
+ KV_IDX = arg_KV_IDX
48
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
49
+ FULL_KV_IDX = arg_FULL_KV_IDX
50
+
51
+ # Sub notation for this kernel:
52
+ #
53
+ # Q: Query, K: Key, V: Value
54
+ # M: Number of queries, N: Number of keys/values, D: Model dimension
55
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
56
+ # V_HEAD_DIM: The dimension of the value embeddings
57
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
58
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
59
+ #
60
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
61
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
62
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
63
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
64
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
65
+ #
66
+ # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
67
+ #
68
+ # (Modifiable) Performance tuning options
69
+ # BLOCK_M: The thread block size across the seqlen dim of Q.
70
+ # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
71
+
72
+ # The below are kernel options that can be applied for certain score_mods,
73
+ # or involve a numerics vs. perf tradeoff
74
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
75
+ # about 20% more numerical error, but slightly faster.
76
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
77
+ # is not masked out? If so, we can skip an extra safety check
78
+ # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
79
+ # contiguous? If so, we don't need to do an indirect jump for every block
80
+
81
+ tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
82
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
83
+
84
+ # Define strides of inputs
85
+ stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1
86
+ stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1
87
+ stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1
88
+
89
+ ZQ = 1
90
+ HQ = 32
91
+ Q_LEN = ks0
92
+ ZKV = 1
93
+ KV_LEN = ks1
94
+
95
+ MATMUL_PRECISION = Q.dtype.element_ty
96
+
97
+ q_start = tl.program_id(0).to(INDEX_DTYPE)
98
+ off_zq = tl.program_id(1).to(INDEX_DTYPE)
99
+ off_hq = tl.program_id(2).to(INDEX_DTYPE)
100
+
101
+ # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
102
+ # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
103
+ off_zkv = off_zq % ZKV
104
+ off_hkv = off_hq // GQA_SHARED_HEADS
105
+ off_g = off_hq % GQA_SHARED_HEADS
106
+
107
+ q_offset = off_zq * stride_qz + off_hq * stride_qh
108
+ k_offset = off_zkv * stride_kz + off_hkv * stride_kh
109
+ v_offset = off_zkv * stride_vz + off_hkv * stride_vh
110
+
111
+ Q = Q + q_offset
112
+ K = K + k_offset
113
+ V = V + v_offset
114
+
115
+ # Setting up the TMA descriptors for Q, K, V
116
+ desc_q = None
117
+ desc_k = None
118
+ desc_v = None
119
+
120
+ SPARSE_Z = 1
121
+ SPARSE_HQ = 1
122
+
123
+ sparse_idx_z = off_zq % SPARSE_Z
124
+ sparse_idx_hq = off_hq % SPARSE_HQ
125
+
126
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
127
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
128
+
129
+ stride_kv_num_blks_h = 1
130
+ stride_kv_idx_h = 1
131
+ stride_kv_idx_m = 1
132
+
133
+ # initialize pointer to m and l
134
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
135
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
136
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
137
+
138
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
139
+
140
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
141
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
142
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
143
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
144
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
145
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
146
+ q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
147
+
148
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
149
+ # We don't know anything "special" about these blocks, so we need to apply
150
+ # both score_mod and mask_mod to it
151
+ kv_indices = KV_IDX + sparse_kv_idx_offset
152
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
153
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
154
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
155
+
156
+
157
+ # K and V pointers will be passed directly to forward_inner
158
+
159
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
160
+
161
+
162
+ acc, l_i, m_i = forward_inner(
163
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
164
+ q, K, V,
165
+ desc_k, desc_v, Q_LEN, KV_LEN,
166
+ acc, l_i, m_i,
167
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
168
+ kv_start,
169
+ kv_indices, kv_num_blocks,
170
+ 0, block_n_end,
171
+ MATMUL_PRECISION,
172
+ stride_kk, stride_kn, stride_vn, stride_vk,
173
+ IS_FULL_BLOCKS=False,
174
+ )
175
+
176
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
177
+ # We know these blocks are guaranteed to be "full", so we don't need to
178
+ # apply mask_mod to them - only score_mod
179
+ if HAS_FULL_BLOCKS:
180
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
181
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
182
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
183
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
184
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
185
+ # K and V pointers will be passed directly to forward_inner
186
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
187
+
188
+ acc, l_i, m_i = forward_inner(
189
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
190
+ q, K, V,
191
+ desc_k, desc_v, Q_LEN, KV_LEN,
192
+ acc, l_i, m_i,
193
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
194
+ kv_start,
195
+ kv_indices, kv_num_blocks,
196
+ 0, block_n_end,
197
+ MATMUL_PRECISION,
198
+ stride_kk, stride_kn, stride_vn, stride_vk,
199
+ IS_FULL_BLOCKS=True,
200
+ )
201
+
202
+
203
+ # [Note] Handle fully masked out rows:
204
+ # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
205
+ # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
206
+ l_i = tl.where(l_i == 0.0, 1, l_i)
207
+
208
+ acc = acc / l_i[:, None]
209
+ idx_zq = tl.program_id(1).to(INDEX_DTYPE)
210
+ idx_hq = tl.program_id(2).to(INDEX_DTYPE)
211
+ idx_m = offs_m[:, None].to(INDEX_DTYPE)
212
+ idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
213
+
214
+ mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
215
+
216
+ tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
217
+ xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0
218
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask)
219
+
220
+ if OUTPUT_LOGSUMEXP:
221
+ off_hz = off_zq * HQ + off_hq
222
+ l_ptrs = LSE + off_hz * Q_LEN + offs_m
223
+ lse = m_i + tl.math.log2(l_i)
224
+ if IS_DIVISIBLE:
225
+ tl.store(l_ptrs, lse)
226
+ else:
227
+ tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
228
+
229
+ if OUTPUT_MAX:
230
+ off_hz = off_zq * HQ + off_hq
231
+ max_ptrs = MAX + off_hz * Q_LEN + offs_m
232
+ if IS_DIVISIBLE:
233
+ tl.store(max_ptrs, m_i)
234
+ else:
235
+ tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
236
+
237
+
238
+ # Utility triton funcs
239
+ @triton.jit
240
+ def get_offset_for_next_block(
241
+ loop_iter, col_indices, total_blocks,
242
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
243
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
244
+ ):
245
+ if BLOCKS_ARE_CONTIGUOUS:
246
+ return BLOCK
247
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
248
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
249
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
250
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
251
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
252
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
253
+ return offset
254
+
255
+ @triton.jit
256
+ def get_bounded_indices(indices, max_len=None):
257
+ return indices % max_len if max_len is not None else indices
258
+
259
+ @triton.jit
260
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
261
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
262
+ return tl.load(block_ptr)
263
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
264
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
265
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
266
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
267
+ else:
268
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
269
+
270
+ @triton.jit
271
+ def load_checked_2d(
272
+ ptr,
273
+ offs_m,
274
+ offs_n,
275
+ stride_m,
276
+ stride_n,
277
+ IS_DIVISIBLE_M: tl.constexpr,
278
+ IS_DIVISIBLE_N: tl.constexpr,
279
+ M_LEN: tl.constexpr,
280
+ N_LEN: tl.constexpr,
281
+ ):
282
+ # Calculate final pointer if strides are provided
283
+ if stride_m is not None and stride_n is not None:
284
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
285
+
286
+ # Handle all masking cases
287
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
288
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
289
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
290
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
291
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
292
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
293
+ else: # Both divisible
294
+ return tl.load(ptr)
295
+
296
+
297
+ # Common Imports
298
+ @triton.jit
299
+ def forward_block_mn(
300
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
301
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
302
+ # accumulated values
303
+ acc, l_i, m_i,
304
+ # Offsets
305
+ off_z, off_h, offs_m, offs_n,
306
+ # Offsets needed for TMA loads
307
+ kv_start,
308
+ kv_offset,
309
+ MATMUL_PRECISION, RCP_LN2,
310
+ # Strides for K and V
311
+ stride_kk, stride_kn, stride_vn, stride_vk,
312
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
313
+
314
+ ):
315
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
316
+ PRESCALE_QK : tl.constexpr = False
317
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
318
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
319
+ WRITE_DQ : tl.constexpr = True
320
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
321
+ OUTPUT_MAX : tl.constexpr = False
322
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
323
+ IS_DIVISIBLE : tl.constexpr = False
324
+ SM_SCALE : tl.constexpr = 0.08838834764831845
325
+ GQA_SHARED_HEADS : tl.constexpr = 4
326
+ HAS_FULL_BLOCKS : tl.constexpr = True
327
+ QK_HEAD_DIM : tl.constexpr = 128
328
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
329
+ V_HEAD_DIM : tl.constexpr = 128
330
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
331
+ SAFE_HEAD_DIM : tl.constexpr = True
332
+ USE_TMA : tl.constexpr = False
333
+ BLOCK_M : tl.constexpr = 128
334
+ BLOCK_N : tl.constexpr = 64
335
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
336
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
337
+ INDEX_DTYPE : tl.constexpr = tl.int32
338
+
339
+
340
+ # -- load k --
341
+ # NB reversed order to since K is transposed
342
+ kv_base_offset = kv_start + kv_offset
343
+
344
+ # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
345
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
346
+ offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
347
+ k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
348
+
349
+ k = tl.trans(k)
350
+ # -- compute qk ---
351
+ qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
352
+ if not PRESCALE_QK:
353
+ qk *= SM_SCALE
354
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
355
+ # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
356
+ # which is larger than the actual number of elements. To avoid access memory out of bound,
357
+ # we need to mask out the elements that are out of Q_LEN & KV_LEN.
358
+ m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
359
+ n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
360
+
361
+ tmp0 = (qk)
362
+ post_mod_scores = tmp0
363
+
364
+
365
+ if CHECK_BLOCK_BOUNDARY:
366
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
367
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
368
+
369
+ if not IS_FULL_BLOCKS:
370
+ tmp1 = (m)
371
+ tmp2 = tl.full([1], 0, tl.int32)
372
+ tmp3 = tmp1 < tmp2
373
+ tmp4 = (n)
374
+ tmp5 = tmp4 <= tmp1
375
+ tmp6 = tmp3 & tmp5
376
+ tmp7 = tmp1 >= tmp2
377
+ tmp8 = tmp4 < tmp2
378
+ tmp9 = tmp7 & tmp8
379
+ tmp10 = tmp8 == 0
380
+ tmp11 = tmp7 & tmp10
381
+ tmp12 = tmp1 - tmp2
382
+ tmp13 = tl.full([1], 16, tl.int32)
383
+ tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
384
+ tmp15 = tmp4 - tmp2
385
+ tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
386
+ tmp17 = tmp14 == tmp16
387
+ tmp18 = tmp11 & tmp17
388
+ tmp19 = tmp9 | tmp18
389
+ tmp20 = tmp6 | tmp19
390
+ mask_mod_output = tmp20
391
+
392
+
393
+ if CHECK_BLOCK_BOUNDARY:
394
+ mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
395
+ # apply mask for partially unmasked blocks
396
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
397
+
398
+ if not PRESCALE_QK:
399
+ post_mod_scores *= RCP_LN2
400
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
401
+
402
+ # -- compute scaling constant ---
403
+ m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
404
+ if not ROWS_GUARANTEED_SAFE:
405
+ masked_out_rows = (m_ij == float("-inf"))
406
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
407
+ else:
408
+ m_ij_masked = m_ij
409
+
410
+ alpha = tl.math.exp2(m_i - m_ij_masked)
411
+ p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
412
+
413
+ # NB: l_i update is pulled up here since it's a bit faster
414
+ # NB: For headdim=256, it's faster to move it back down to after m_i =
415
+ # m_ij
416
+ l_i = l_i * alpha + tl.sum(p, 1)
417
+ # # -- scale and update acc --
418
+ acc = acc * alpha[:, None]
419
+ # Calculate offsets for V loading - reuse kv_base_offset from K loading
420
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
421
+ v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
422
+ acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
423
+
424
+ # -- update m_i
425
+ m_i = m_ij
426
+
427
+ return acc, l_i, m_i
428
+
429
+ @triton.jit
430
+ def forward_inner(
431
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
432
+ q, K, V,
433
+ desc_k, desc_v, Q_LEN, KV_LEN,
434
+ # accumulated values
435
+ acc, l_i, m_i,
436
+ # Offsets used as inputs to score_mod & mask_mod
437
+ # of size [BLOCK_M, BLOCK_N] or scalar.
438
+ off_z, off_h, offs_m, offs_n,
439
+ # Offsets needed for TMA loads
440
+ kv_start,
441
+ # blocksparse data
442
+ kv_indices, kv_num_blocks,
443
+ # start kv and end kv block
444
+ block_n_start, block_n_end,
445
+ MATMUL_PRECISION,
446
+ # Strides for K and V
447
+ stride_kk, stride_kn, stride_vn, stride_vk,
448
+ IS_FULL_BLOCKS,
449
+ ):
450
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
451
+ PRESCALE_QK : tl.constexpr = False
452
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
453
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
454
+ WRITE_DQ : tl.constexpr = True
455
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
456
+ OUTPUT_MAX : tl.constexpr = False
457
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
458
+ IS_DIVISIBLE : tl.constexpr = False
459
+ SM_SCALE : tl.constexpr = 0.08838834764831845
460
+ GQA_SHARED_HEADS : tl.constexpr = 4
461
+ HAS_FULL_BLOCKS : tl.constexpr = True
462
+ QK_HEAD_DIM : tl.constexpr = 128
463
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
464
+ V_HEAD_DIM : tl.constexpr = 128
465
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
466
+ SAFE_HEAD_DIM : tl.constexpr = True
467
+ USE_TMA : tl.constexpr = False
468
+ BLOCK_M : tl.constexpr = 128
469
+ BLOCK_N : tl.constexpr = 64
470
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
471
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
472
+ INDEX_DTYPE : tl.constexpr = tl.int32
473
+
474
+
475
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
476
+ RCP_LN2: tl.constexpr = 1.44269504
477
+
478
+ if PRESCALE_QK:
479
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
480
+
481
+ kv_offset = 0
482
+
483
+ # loop over k, v and update accumulator until block_n_end
484
+ for start_n in range(block_n_start, block_n_end):
485
+ # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
486
+ if IS_DIVISIBLE:
487
+ acc, l_i, m_i = forward_block_mn(
488
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
489
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
490
+ # accumulated values
491
+ acc, l_i, m_i,
492
+ # Offsets
493
+ off_z, off_h, offs_m, offs_n,
494
+ # Offsets needed for TMA loads
495
+ kv_start,
496
+ kv_offset,
497
+ MATMUL_PRECISION, RCP_LN2,
498
+ # Strides for K and V
499
+ stride_kk, stride_kn, stride_vn, stride_vk,
500
+ IS_FULL_BLOCKS,
501
+ )
502
+ else:
503
+ # Benchmark shows even we applied mod & mask to each block for non divisible seqlen,
504
+ # it's on par or slightly faster than only applying to the last block in fwd.
505
+ # However, we choose different strategy for bwd, where we only apply mod & mask
506
+ # to the last block because it's faster a lot.
507
+ acc, l_i, m_i = forward_block_mn(
508
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
509
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
510
+ # accumulated values
511
+ acc, l_i, m_i,
512
+ # Offsets
513
+ off_z, off_h, offs_m, offs_n,
514
+ # Offsets needed for TMA loads
515
+ kv_start,
516
+ kv_offset,
517
+ MATMUL_PRECISION, RCP_LN2,
518
+ # Strides for K and V
519
+ stride_kk, stride_kn, stride_vn, stride_vk,
520
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
521
+ )
522
+
523
+
524
+
525
+ offset = get_offset_for_next_block(
526
+ start_n, kv_indices, kv_num_blocks,
527
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
528
+ )
529
+
530
+ offs_n = offs_n + offset
531
+ kv_offset += offset
532
+
533
+
534
+ return acc, l_i, m_i
progress/SpecForge/cache/compiled_kernels/2j/4b74fa21eaaf86b6290185f6fe50aec9b905d858a087238ceddb52477f3f6acb.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "2ZIFGDABR2MKMG7ESWF67GBZDP27JEZIQWMBXPOUZFGMG5PW5DSA"}
progress/SpecForge/cache/compiled_kernels/2j/c2j3mtk3thi6sn2hxiuhuigjw43spiu74mxdervpgpfrtos7u2qh.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.pointwise(
11
+ size_hints={'x': 4096},
12
+ filename=__file__,
13
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
15
+ min_elem_per_thread=0
16
+ )
17
+ @triton.jit
18
+ def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
19
+ xoffset = tl.program_id(0) * XBLOCK
20
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
21
+ xmask = xindex < xnumel
22
+ x2 = xindex
23
+ x0 = (xindex % ks0)
24
+ x1 = triton_helpers.div_floor_integer(xindex, ks0)
25
+ tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last')
26
+ tmp1 = 0.6931471805599453
27
+ tmp2 = tmp0 * tmp1
28
+ tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask)
progress/SpecForge/cache/compiled_kernels/2n/c2ngvuchx6agpdr6v7awl3qgblaehfzaauoxn6camwvtk7syoxsk.py ADDED
@@ -0,0 +1,715 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['4_inference']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/6u/c6uror2yjtc6vpcc3on3oq3lwi6yghlxrmwz5rocw5haxvfiz47e.py
38
+ # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
39
+ # Source node to ATen node mapping:
40
+ # flex_attention => flex_attention
41
+ # Graph fragment:
42
+ # %arg1_1 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:2" = PlaceHolder[target=arg1_1]
43
+ # %arg3_1 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:2" = PlaceHolder[target=arg3_1]
44
+ # %arg5_1 : Tensor "bf16[1, 8, s43, 128][1024*s43, 128, 1024, 1]cuda:2" = PlaceHolder[target=arg5_1]
45
+ # %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:2" = PlaceHolder[target=getitem_1]
46
+ # %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:2" = PlaceHolder[target=buf1]
47
+ # %arg9_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:2" = PlaceHolder[target=arg9_1]
48
+ # %arg6_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:2" = PlaceHolder[target=arg6_1]
49
+ # %arg10_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:2" = PlaceHolder[target=arg10_1]
50
+ # %arg11_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:2" = PlaceHolder[target=arg11_1]
51
+ # %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
52
+ # return %getitem
53
+ triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', '''
54
+ import triton
55
+ import triton.language as tl
56
+
57
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
58
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
59
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
60
+
61
+ @triton_heuristics.template(
62
+
63
+ num_stages=3,
64
+ num_warps=8,
65
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]},
66
+ inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
67
+
68
+ )
69
+ @triton.jit
70
+ def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2):
71
+ PRESCALE_QK : tl.constexpr = False
72
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
73
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
74
+ WRITE_DQ : tl.constexpr = True
75
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
76
+ OUTPUT_MAX : tl.constexpr = False
77
+ FLOAT32_PRECISION : tl.constexpr = 'ieee'
78
+ IS_DIVISIBLE : tl.constexpr = False
79
+ SM_SCALE : tl.constexpr = 0.08838834764831845
80
+ GQA_SHARED_HEADS : tl.constexpr = 4
81
+ HAS_FULL_BLOCKS : tl.constexpr = True
82
+ QK_HEAD_DIM : tl.constexpr = 128
83
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
84
+ V_HEAD_DIM : tl.constexpr = 128
85
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
86
+ SAFE_HEAD_DIM : tl.constexpr = True
87
+ USE_TMA : tl.constexpr = False
88
+ BLOCK_M : tl.constexpr = 128
89
+ BLOCK_N : tl.constexpr = 64
90
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
91
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
92
+ INDEX_DTYPE : tl.constexpr = tl.int32
93
+ Q = arg_Q
94
+ K = arg_K
95
+ V = arg_V
96
+ LSE = arg_LSE
97
+ MAX = arg_MAX
98
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
99
+ KV_IDX = arg_KV_IDX
100
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
101
+ FULL_KV_IDX = arg_FULL_KV_IDX
102
+
103
+ # Sub notation for this kernel:
104
+ #
105
+ # Q: Query, K: Key, V: Value
106
+ # M: Number of queries, N: Number of keys/values, D: Model dimension
107
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
108
+ # V_HEAD_DIM: The dimension of the value embeddings
109
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
110
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
111
+ #
112
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
113
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
114
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
115
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
116
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
117
+ #
118
+ # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
119
+ #
120
+ # (Modifiable) Performance tuning options
121
+ # BLOCK_M: The thread block size across the seqlen dim of Q.
122
+ # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
123
+
124
+ # The below are kernel options that can be applied for certain score_mods,
125
+ # or involve a numerics vs. perf tradeoff
126
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
127
+ # about 20% more numerical error, but slightly faster.
128
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
129
+ # is not masked out? If so, we can skip an extra safety check
130
+ # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
131
+ # contiguous? If so, we don't need to do an indirect jump for every block
132
+
133
+ tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
134
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
135
+
136
+ # Define strides of inputs
137
+ stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1
138
+ stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1
139
+ stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1
140
+
141
+ ZQ = 1
142
+ HQ = 32
143
+ Q_LEN = ks0
144
+ ZKV = 1
145
+ KV_LEN = ks1
146
+
147
+ MATMUL_PRECISION = Q.dtype.element_ty
148
+
149
+ q_start = tl.program_id(0).to(INDEX_DTYPE)
150
+ off_zq = tl.program_id(1).to(INDEX_DTYPE)
151
+ off_hq = tl.program_id(2).to(INDEX_DTYPE)
152
+
153
+ # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
154
+ # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
155
+ off_zkv = off_zq % ZKV
156
+ off_hkv = off_hq // GQA_SHARED_HEADS
157
+ off_g = off_hq % GQA_SHARED_HEADS
158
+
159
+ q_offset = off_zq * stride_qz + off_hq * stride_qh
160
+ k_offset = off_zkv * stride_kz + off_hkv * stride_kh
161
+ v_offset = off_zkv * stride_vz + off_hkv * stride_vh
162
+
163
+ Q = Q + q_offset
164
+ K = K + k_offset
165
+ V = V + v_offset
166
+
167
+ # Setting up the TMA descriptors for Q, K, V
168
+ desc_q = None
169
+ desc_k = None
170
+ desc_v = None
171
+
172
+ SPARSE_Z = 1
173
+ SPARSE_HQ = 1
174
+
175
+ sparse_idx_z = off_zq % SPARSE_Z
176
+ sparse_idx_hq = off_hq % SPARSE_HQ
177
+
178
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
179
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
180
+
181
+ stride_kv_num_blks_h = 1
182
+ stride_kv_idx_h = 1
183
+ stride_kv_idx_m = 1
184
+
185
+ # initialize pointer to m and l
186
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
187
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
188
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
189
+
190
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
191
+
192
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
193
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
194
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
195
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
196
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
197
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
198
+ q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
199
+
200
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
201
+ # We don't know anything "special" about these blocks, so we need to apply
202
+ # both score_mod and mask_mod to it
203
+ kv_indices = KV_IDX + sparse_kv_idx_offset
204
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
205
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
206
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
207
+
208
+
209
+ # K and V pointers will be passed directly to forward_inner
210
+
211
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
212
+
213
+
214
+ acc, l_i, m_i = forward_inner(
215
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
216
+ q, K, V,
217
+ desc_k, desc_v, Q_LEN, KV_LEN,
218
+ acc, l_i, m_i,
219
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
220
+ kv_start,
221
+ kv_indices, kv_num_blocks,
222
+ 0, block_n_end,
223
+ MATMUL_PRECISION,
224
+ stride_kk, stride_kn, stride_vn, stride_vk,
225
+ IS_FULL_BLOCKS=False,
226
+ )
227
+
228
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
229
+ # We know these blocks are guaranteed to be "full", so we don't need to
230
+ # apply mask_mod to them - only score_mod
231
+ if HAS_FULL_BLOCKS:
232
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
233
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
234
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
235
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
236
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
237
+ # K and V pointers will be passed directly to forward_inner
238
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
239
+
240
+ acc, l_i, m_i = forward_inner(
241
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
242
+ q, K, V,
243
+ desc_k, desc_v, Q_LEN, KV_LEN,
244
+ acc, l_i, m_i,
245
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
246
+ kv_start,
247
+ kv_indices, kv_num_blocks,
248
+ 0, block_n_end,
249
+ MATMUL_PRECISION,
250
+ stride_kk, stride_kn, stride_vn, stride_vk,
251
+ IS_FULL_BLOCKS=True,
252
+ )
253
+
254
+
255
+ # [Note] Handle fully masked out rows:
256
+ # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
257
+ # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
258
+ l_i = tl.where(l_i == 0.0, 1, l_i)
259
+
260
+ acc = acc / l_i[:, None]
261
+ idx_zq = tl.program_id(1).to(INDEX_DTYPE)
262
+ idx_hq = tl.program_id(2).to(INDEX_DTYPE)
263
+ idx_m = offs_m[:, None].to(INDEX_DTYPE)
264
+ idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
265
+
266
+ mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
267
+
268
+ tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
269
+ xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0
270
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask)
271
+
272
+ if OUTPUT_LOGSUMEXP:
273
+ off_hz = off_zq * HQ + off_hq
274
+ l_ptrs = LSE + off_hz * Q_LEN + offs_m
275
+ lse = m_i + tl.math.log2(l_i)
276
+ if IS_DIVISIBLE:
277
+ tl.store(l_ptrs, lse)
278
+ else:
279
+ tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
280
+
281
+ if OUTPUT_MAX:
282
+ off_hz = off_zq * HQ + off_hq
283
+ max_ptrs = MAX + off_hz * Q_LEN + offs_m
284
+ if IS_DIVISIBLE:
285
+ tl.store(max_ptrs, m_i)
286
+ else:
287
+ tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
288
+
289
+
290
+ # Utility triton funcs
291
+ @triton.jit
292
+ def get_offset_for_next_block(
293
+ loop_iter, col_indices, total_blocks,
294
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
295
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
296
+ ):
297
+ if BLOCKS_ARE_CONTIGUOUS:
298
+ return BLOCK
299
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
300
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
301
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
302
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
303
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
304
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
305
+ return offset
306
+
307
+ @triton.jit
308
+ def get_bounded_indices(indices, max_len=None):
309
+ return indices % max_len if max_len is not None else indices
310
+
311
+ @triton.jit
312
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
313
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
314
+ return tl.load(block_ptr)
315
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
316
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
317
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
318
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
319
+ else:
320
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
321
+
322
+ @triton.jit
323
+ def load_checked_2d(
324
+ ptr,
325
+ offs_m,
326
+ offs_n,
327
+ stride_m,
328
+ stride_n,
329
+ IS_DIVISIBLE_M: tl.constexpr,
330
+ IS_DIVISIBLE_N: tl.constexpr,
331
+ M_LEN: tl.constexpr,
332
+ N_LEN: tl.constexpr,
333
+ ):
334
+ # Calculate final pointer if strides are provided
335
+ if stride_m is not None and stride_n is not None:
336
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
337
+
338
+ # Handle all masking cases
339
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
340
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
341
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
342
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
343
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
344
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
345
+ else: # Both divisible
346
+ return tl.load(ptr)
347
+
348
+
349
+ # Common Imports
350
+ @triton.jit
351
+ def forward_block_mn(
352
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
353
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
354
+ # accumulated values
355
+ acc, l_i, m_i,
356
+ # Offsets
357
+ off_z, off_h, offs_m, offs_n,
358
+ # Offsets needed for TMA loads
359
+ kv_start,
360
+ kv_offset,
361
+ MATMUL_PRECISION, RCP_LN2,
362
+ # Strides for K and V
363
+ stride_kk, stride_kn, stride_vn, stride_vk,
364
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
365
+
366
+ ):
367
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
368
+ PRESCALE_QK : tl.constexpr = False
369
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
370
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
371
+ WRITE_DQ : tl.constexpr = True
372
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
373
+ OUTPUT_MAX : tl.constexpr = False
374
+ FLOAT32_PRECISION : tl.constexpr = 'ieee'
375
+ IS_DIVISIBLE : tl.constexpr = False
376
+ SM_SCALE : tl.constexpr = 0.08838834764831845
377
+ GQA_SHARED_HEADS : tl.constexpr = 4
378
+ HAS_FULL_BLOCKS : tl.constexpr = True
379
+ QK_HEAD_DIM : tl.constexpr = 128
380
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
381
+ V_HEAD_DIM : tl.constexpr = 128
382
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
383
+ SAFE_HEAD_DIM : tl.constexpr = True
384
+ USE_TMA : tl.constexpr = False
385
+ BLOCK_M : tl.constexpr = 128
386
+ BLOCK_N : tl.constexpr = 64
387
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
388
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
389
+ INDEX_DTYPE : tl.constexpr = tl.int32
390
+
391
+
392
+ # -- load k --
393
+ # NB reversed order to since K is transposed
394
+ kv_base_offset = kv_start + kv_offset
395
+
396
+ # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
397
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
398
+ offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
399
+ k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
400
+
401
+ k = tl.trans(k)
402
+ # -- compute qk ---
403
+ qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
404
+ if not PRESCALE_QK:
405
+ qk *= SM_SCALE
406
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
407
+ # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
408
+ # which is larger than the actual number of elements. To avoid access memory out of bound,
409
+ # we need to mask out the elements that are out of Q_LEN & KV_LEN.
410
+ m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
411
+ n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
412
+
413
+ tmp0 = (qk)
414
+ post_mod_scores = tmp0
415
+
416
+
417
+ if CHECK_BLOCK_BOUNDARY:
418
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
419
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
420
+
421
+ if not IS_FULL_BLOCKS:
422
+ tmp1 = (m)
423
+ tmp2 = tl.full([1], 0, tl.int32)
424
+ tmp3 = tmp1 < tmp2
425
+ tmp4 = (n)
426
+ tmp5 = tmp4 <= tmp1
427
+ tmp6 = tmp3 & tmp5
428
+ tmp7 = tmp1 >= tmp2
429
+ tmp8 = tmp4 < tmp2
430
+ tmp9 = tmp7 & tmp8
431
+ tmp10 = tmp8 == 0
432
+ tmp11 = tmp7 & tmp10
433
+ tmp12 = tmp1 - tmp2
434
+ tmp13 = tl.full([1], 16, tl.int32)
435
+ tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
436
+ tmp15 = tmp4 - tmp2
437
+ tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
438
+ tmp17 = tmp14 == tmp16
439
+ tmp18 = tmp11 & tmp17
440
+ tmp19 = tmp9 | tmp18
441
+ tmp20 = tmp6 | tmp19
442
+ mask_mod_output = tmp20
443
+
444
+
445
+ if CHECK_BLOCK_BOUNDARY:
446
+ mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
447
+ # apply mask for partially unmasked blocks
448
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
449
+
450
+ if not PRESCALE_QK:
451
+ post_mod_scores *= RCP_LN2
452
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
453
+
454
+ # -- compute scaling constant ---
455
+ m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
456
+ if not ROWS_GUARANTEED_SAFE:
457
+ masked_out_rows = (m_ij == float("-inf"))
458
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
459
+ else:
460
+ m_ij_masked = m_ij
461
+
462
+ alpha = tl.math.exp2(m_i - m_ij_masked)
463
+ p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
464
+
465
+ # NB: l_i update is pulled up here since it's a bit faster
466
+ # NB: For headdim=256, it's faster to move it back down to after m_i =
467
+ # m_ij
468
+ l_i = l_i * alpha + tl.sum(p, 1)
469
+ # # -- scale and update acc --
470
+ acc = acc * alpha[:, None]
471
+ # Calculate offsets for V loading - reuse kv_base_offset from K loading
472
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
473
+ v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
474
+ acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
475
+
476
+ # -- update m_i
477
+ m_i = m_ij
478
+
479
+ return acc, l_i, m_i
480
+
481
+ @triton.jit
482
+ def forward_inner(
483
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
484
+ q, K, V,
485
+ desc_k, desc_v, Q_LEN, KV_LEN,
486
+ # accumulated values
487
+ acc, l_i, m_i,
488
+ # Offsets used as inputs to score_mod & mask_mod
489
+ # of size [BLOCK_M, BLOCK_N] or scalar.
490
+ off_z, off_h, offs_m, offs_n,
491
+ # Offsets needed for TMA loads
492
+ kv_start,
493
+ # blocksparse data
494
+ kv_indices, kv_num_blocks,
495
+ # start kv and end kv block
496
+ block_n_start, block_n_end,
497
+ MATMUL_PRECISION,
498
+ # Strides for K and V
499
+ stride_kk, stride_kn, stride_vn, stride_vk,
500
+ IS_FULL_BLOCKS,
501
+ ):
502
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
503
+ PRESCALE_QK : tl.constexpr = False
504
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
505
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
506
+ WRITE_DQ : tl.constexpr = True
507
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
508
+ OUTPUT_MAX : tl.constexpr = False
509
+ FLOAT32_PRECISION : tl.constexpr = 'ieee'
510
+ IS_DIVISIBLE : tl.constexpr = False
511
+ SM_SCALE : tl.constexpr = 0.08838834764831845
512
+ GQA_SHARED_HEADS : tl.constexpr = 4
513
+ HAS_FULL_BLOCKS : tl.constexpr = True
514
+ QK_HEAD_DIM : tl.constexpr = 128
515
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
516
+ V_HEAD_DIM : tl.constexpr = 128
517
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
518
+ SAFE_HEAD_DIM : tl.constexpr = True
519
+ USE_TMA : tl.constexpr = False
520
+ BLOCK_M : tl.constexpr = 128
521
+ BLOCK_N : tl.constexpr = 64
522
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
523
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
524
+ INDEX_DTYPE : tl.constexpr = tl.int32
525
+
526
+
527
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
528
+ RCP_LN2: tl.constexpr = 1.44269504
529
+
530
+ if PRESCALE_QK:
531
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
532
+
533
+ kv_offset = 0
534
+
535
+ # loop over k, v and update accumulator until block_n_end
536
+ for start_n in range(block_n_start, block_n_end):
537
+ # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
538
+ if IS_DIVISIBLE:
539
+ acc, l_i, m_i = forward_block_mn(
540
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
541
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
542
+ # accumulated values
543
+ acc, l_i, m_i,
544
+ # Offsets
545
+ off_z, off_h, offs_m, offs_n,
546
+ # Offsets needed for TMA loads
547
+ kv_start,
548
+ kv_offset,
549
+ MATMUL_PRECISION, RCP_LN2,
550
+ # Strides for K and V
551
+ stride_kk, stride_kn, stride_vn, stride_vk,
552
+ IS_FULL_BLOCKS,
553
+ )
554
+ else:
555
+ # Benchmark shows even we applied mod & mask to each block for non divisible seqlen,
556
+ # it's on par or slightly faster than only applying to the last block in fwd.
557
+ # However, we choose different strategy for bwd, where we only apply mod & mask
558
+ # to the last block because it's faster a lot.
559
+ acc, l_i, m_i = forward_block_mn(
560
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
561
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
562
+ # accumulated values
563
+ acc, l_i, m_i,
564
+ # Offsets
565
+ off_z, off_h, offs_m, offs_n,
566
+ # Offsets needed for TMA loads
567
+ kv_start,
568
+ kv_offset,
569
+ MATMUL_PRECISION, RCP_LN2,
570
+ # Strides for K and V
571
+ stride_kk, stride_kn, stride_vn, stride_vk,
572
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
573
+ )
574
+
575
+
576
+
577
+ offset = get_offset_for_next_block(
578
+ start_n, kv_indices, kv_num_blocks,
579
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
580
+ )
581
+
582
+ offs_n = offs_n + offset
583
+ kv_offset += offset
584
+
585
+
586
+ return acc, l_i, m_i
587
+ ''', device_str='cuda')
588
+
589
+
590
+ # kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sc/cscnwzzlpcjsqvndc4tlfwact2ecwdimqtwu2vya2cnto5t7c7pi.py
591
+ # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul]
592
+ # Source node to ATen node mapping:
593
+ # lse_scaled => mul_9
594
+ # Graph fragment:
595
+ # %buf3 : Tensor = PlaceHolder[target=buf3]
596
+ # %mul_9 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {})
597
+ # return %mul_9
598
+ triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', '''
599
+ import triton
600
+ import triton.language as tl
601
+
602
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
603
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
604
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
605
+ triton_helpers.set_driver_to_gpu()
606
+
607
+ @triton_heuristics.pointwise(
608
+ size_hints={'x': 4096},
609
+ filename=__file__,
610
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
611
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
612
+ min_elem_per_thread=0
613
+ )
614
+ @triton.jit
615
+ def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
616
+ xoffset = tl.program_id(0) * XBLOCK
617
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
618
+ xmask = xindex < xnumel
619
+ x2 = xindex
620
+ x0 = (xindex % ks0)
621
+ x1 = triton_helpers.div_floor_integer(xindex, ks0)
622
+ tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last')
623
+ tmp1 = 0.6931471805599453
624
+ tmp2 = tmp0 * tmp1
625
+ tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask)
626
+ ''', device_str='cuda')
627
+
628
+
629
+ async_compile.wait(globals())
630
+ del async_compile
631
+
632
+ class Runner:
633
+ def __init__(self, partitions):
634
+ self.partitions = partitions
635
+
636
+ def recursively_apply_fns(self, fns):
637
+ new_callables = []
638
+ for fn, c in zip(fns, self.partitions):
639
+ new_callables.append(fn(c))
640
+ self.partitions = new_callables
641
+
642
+ def call(self, args):
643
+ arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1 = args
644
+ args.clear()
645
+ s50 = arg0_1
646
+ s0 = arg2_1
647
+ s43 = arg4_1
648
+ s37 = arg7_1
649
+ s71 = arg8_1
650
+ assert_size_stride(arg1_1, (1, 32, s37, 128), (4096*s37, 128, 4096, 1))
651
+ assert_size_stride(arg3_1, (1, 8, s0, 128), (1024*s0, 128, 1024, 1))
652
+ assert_size_stride(arg5_1, (1, 8, s43, 128), (1024*s43, 128, 1024, 1))
653
+ assert_size_stride(arg6_1, (1, 1, 1, 1), (1, 1, 1, 1))
654
+ assert_size_stride(arg9_1, (1, 1, 1), (1, 1, 1))
655
+ assert_size_stride(arg10_1, (1, 1, 1), (1, 1, 1))
656
+ assert_size_stride(arg11_1, (1, 1, 1, 1), (1, 1, 1, 1))
657
+ assert_size_stride(arg12_1, (1, 1, 1), (1, 1, 1))
658
+ assert_size_stride(arg13_1, (1, 1, 1, 1), (1, 1, 1, 1))
659
+ assert_size_stride(arg14_1, (1, 1, 1), (1, 1, 1))
660
+ assert_size_stride(arg15_1, (1, 1, 1, 1), (1, 1, 1, 1))
661
+ with torch.cuda._DeviceGuard(2):
662
+ torch.cuda.set_device(2)
663
+ buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32)
664
+ buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32)
665
+ buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16)
666
+ # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
667
+ stream2 = get_raw_stream(2)
668
+ triton_tem_fused_0.run(arg1_1, arg3_1, arg5_1, buf0, buf1, arg9_1, arg6_1, arg10_1, arg11_1, buf2, s37, s0, s43, (127 + s37) // 128, 1, 32, stream=stream2)
669
+ del arg10_1
670
+ del arg11_1
671
+ del arg1_1
672
+ del arg3_1
673
+ del arg5_1
674
+ del arg6_1
675
+ del arg9_1
676
+ del buf1
677
+ buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32)
678
+ # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul]
679
+ triton_poi_fused_mul_1_xnumel = 32*s37
680
+ stream2 = get_raw_stream(2)
681
+ triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream2)
682
+ del buf0
683
+ return (buf2, buf5, )
684
+
685
+ runner = Runner(partitions=[])
686
+ call = runner.call
687
+ recursively_apply_fns = runner.recursively_apply_fns
688
+
689
+
690
+ def benchmark_compiled_module(times=10, repeat=10):
691
+ from torch._dynamo.testing import rand_strided
692
+ from torch._inductor.utils import print_performance
693
+ arg0_1 = 128
694
+ arg1_1 = rand_strided((1, 32, 128, 128), (524288, 128, 4096, 1), device='cuda:2', dtype=torch.bfloat16)
695
+ arg2_1 = 128
696
+ arg3_1 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:2', dtype=torch.bfloat16)
697
+ arg4_1 = 128
698
+ arg5_1 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:2', dtype=torch.bfloat16)
699
+ arg6_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:2', dtype=torch.int32)
700
+ arg7_1 = 128
701
+ arg8_1 = 128
702
+ arg9_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:2', dtype=torch.int32)
703
+ arg10_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:2', dtype=torch.int32)
704
+ arg11_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:2', dtype=torch.int32)
705
+ arg12_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:2', dtype=torch.int32)
706
+ arg13_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:2', dtype=torch.int32)
707
+ arg14_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:2', dtype=torch.int32)
708
+ arg15_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:2', dtype=torch.int32)
709
+ fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1])
710
+ return print_performance(fn, times=times, repeat=repeat)
711
+
712
+
713
+ if __name__ == "__main__":
714
+ from torch._inductor.wrapper_benchmark import compiled_module_main
715
+ compiled_module_main('None', benchmark_compiled_module)
progress/SpecForge/cache/compiled_kernels/2n/c2nooi7ekpz4qvmvghggbegd5cyfspb27jmq2snbi26zbrpoibnx.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.persistent_reduction(
11
+ size_hints={'x': 4096, 'r0_': 128},
12
+ reduction_hint=ReductionHint.INNER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_per_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr):
19
+ r0_numel = 128
20
+ R0_BLOCK: tl.constexpr = 128
21
+ rnumel = r0_numel
22
+ RBLOCK: tl.constexpr = R0_BLOCK
23
+ xoffset = tl.program_id(0) * XBLOCK
24
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
25
+ xmask = xindex < xnumel
26
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
27
+ r0_offset = 0
28
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
29
+ roffset = r0_offset
30
+ rindex = r0_index
31
+ r0_2 = r0_index
32
+ x0 = (xindex % ks0)
33
+ x1 = triton_helpers.div_floor_integer(xindex, ks0)
34
+ x3 = xindex
35
+ tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), xmask, other=0.0).to(tl.float32)
36
+ tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, other=0.0).to(tl.float32)
37
+ tmp8 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last')
38
+ tmp2 = tmp0 * tmp1
39
+ tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
40
+ tmp5 = tl.where(xmask, tmp3, 0)
41
+ tmp6 = tl.sum(tmp5, 1)[:, None].to(tl.float32)
42
+ tmp7 = tmp6.to(tl.float32)
43
+ tmp9 = 0.6931471805599453
44
+ tmp10 = tmp8 * tmp9
45
+ tmp11 = 1.4426950408889634
46
+ tmp12 = tmp10 * tmp11
47
+ tmp13 = tmp7 - tmp12
48
+ tl.store(out_ptr1 + (x3), tmp13, xmask)
progress/SpecForge/cache/compiled_kernels/2n/d17ff4e7bb44e5ae89a267ef332bb7c074804ce0942fc0694c3ef15b05f7854a.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 8, "num_warps": 8, "num_stages": 1, "configs_hash": "22b8c9e89632e6687ce26aaad980a76bbf5ee683fff317f3a6d7989c7528ff63", "found_by_coordesc": false, "time_taken_ms": 18, "triton_cache_hash": "WJHIHLPATQZBKSQZSWJ5BD3ABYGFUF3YD6VF633RGCNWMMKVXCCA"}
progress/SpecForge/cache/compiled_kernels/2o/c2oashzxz74kzyuwo67tuhk32cike37ysabriftachdv7lf2qxgs.py ADDED
@@ -0,0 +1,799 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = False
27
+ SM_SCALE : tl.constexpr = 0.08838834764831845
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ BLOCK_M1 : tl.constexpr = 64
36
+ BLOCK_N1 : tl.constexpr = 128
37
+ BLOCK_M2 : tl.constexpr = 128
38
+ BLOCK_N2 : tl.constexpr = 64
39
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
40
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
41
+ INDEX_DTYPE : tl.constexpr = tl.int32
42
+ Q = arg_Q
43
+ K = arg_K
44
+ V = arg_V
45
+ LSE = arg_LSE
46
+ DELTA = arg_DELTA
47
+ DO = arg_DO
48
+ DQ = arg_DQ
49
+ DV = arg_DV
50
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
51
+ KV_IDX = arg_KV_IDX
52
+ Q_NUM_BLKS = arg_Q_NUM_BLKS
53
+ Q_IDX = arg_Q_IDX
54
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
55
+ FULL_KV_IDX = arg_FULL_KV_IDX
56
+ FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
57
+ FULL_Q_IDX = arg_FULL_Q_IDX
58
+
59
+ # Sub notation for this kernel:
60
+ #
61
+ # Q: Query, K: Key, V: Value
62
+ # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
63
+ # DELTA: Precomputed sum(OUT*DO, axis=-1)
64
+ # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
65
+ # DK: Derivative of Key, is the written to via the store_output call due to some limitations with
66
+ # inductor codegen
67
+ # M: Number of queries, N: Number of keys/values
68
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
69
+ # V_HEAD_DIM: The dimension of the value embeddings
70
+ # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
71
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
72
+ # (Modifiable) Performance tuning options
73
+ # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
74
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
75
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
76
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
77
+ #
78
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
79
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
80
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
81
+ # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
82
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
83
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
84
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
85
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
86
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
87
+
88
+ # The below are kernel options that can be applied for certain score_mods,
89
+ # or involve a numerics vs. perf tradeoff
90
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
91
+ # about 20% more numerical error, but slightly faster.
92
+
93
+ # Define strides of inputs
94
+ stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1
95
+ stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1
96
+ stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1
97
+ stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1
98
+
99
+ stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1
100
+ stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1
101
+
102
+ ZQ = 1
103
+ HQ = 32
104
+ HKV = 8
105
+ Q_LEN = ks0
106
+ ZKV = 1
107
+ KV_LEN = ks1
108
+
109
+ MATMUL_PRECISION = Q.dtype.element_ty
110
+
111
+ pid = tl.program_id(0).to(INDEX_DTYPE)
112
+ NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
113
+ NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
114
+
115
+ off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
116
+ off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
117
+ off_zkv = off_zq % ZKV # kv batch idx
118
+
119
+ SPARSE_Z = 1
120
+ SPARSE_HQ = 1
121
+
122
+ sparse_idx_z = off_zq % SPARSE_Z
123
+
124
+ k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
125
+ v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
126
+ # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
127
+ # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
128
+ dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
129
+
130
+ # offset K, V, DV pointers for batch/kv-head
131
+ K += k_adj
132
+ V += v_adj
133
+ DV += dv_adj
134
+
135
+ RCP_LN2 = 1.44269504
136
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
137
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
138
+
139
+ if pid >= NUM_KV_BLOCKS:
140
+ off_pid = pid - NUM_KV_BLOCKS
141
+ # THIS BLOCK DOES DQ
142
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
143
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
144
+ off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
145
+ start_m2_block = off_pid % NUM_Q_BLOCKS
146
+ off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
147
+ stride_kv_num_blks_h = ks2
148
+ stride_kv_idx_h = ks3*ks4
149
+ stride_kv_idx_m = ks4
150
+
151
+ sparse_idx_hq2 = off_hq2 % SPARSE_HQ
152
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
153
+
154
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
155
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
156
+
157
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
158
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
159
+ do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
160
+ dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
161
+ off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
162
+
163
+ Q2 = Q + q_adj2
164
+ DO2 = DO + do_adj2
165
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
166
+ # if Q is broadcasted)
167
+ DQ2 = DQ + dq_adj2
168
+ LSE2 = LSE + off_chz2
169
+ DELTA2 = DELTA + off_chz2
170
+
171
+ # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
172
+ dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
173
+
174
+ start_m2 = start_m2_block * BLOCK_M2
175
+ offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
176
+
177
+ # load Q and do: they stay in SRAM throughout the inner loop.
178
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
179
+ do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
180
+
181
+ if PRESCALE_QK:
182
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
183
+
184
+ if IS_DIVISIBLE:
185
+ Di = tl.load(DELTA2 + offs_m2)
186
+ lse = tl.load(LSE2 + offs_m2)
187
+ else:
188
+ Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
189
+ lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
190
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
191
+ lse = lse[:, None]
192
+
193
+ # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
194
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
195
+ kv_indices = KV_IDX + sparse_kv_idx_offset
196
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
197
+ sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
198
+
199
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
200
+ dq = bwd_dq_inner(
201
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7,
202
+ K, V,
203
+ dq, q, do, Di, lse,
204
+ off_zq, off_hq2, offs_m2, offs_n2,
205
+ stride_kn, stride_kd, stride_vn, stride_vd,
206
+ kv_indices, sparse_kv_num_blocks,
207
+ MATMUL_PRECISION,
208
+ IS_FULL_BLOCKS=False,
209
+ )
210
+
211
+ if HAS_FULL_BLOCKS:
212
+ # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
213
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
214
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
215
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
216
+ sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
217
+
218
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
219
+ dq = bwd_dq_inner(
220
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7,
221
+ K, V,
222
+ dq, q, do, Di, lse,
223
+ off_zq, off_hq2, offs_m2, offs_n2,
224
+ stride_kn, stride_kd, stride_vn, stride_vd,
225
+ kv_indices, sparse_kv_num_blocks,
226
+ MATMUL_PRECISION,
227
+ IS_FULL_BLOCKS=True,
228
+ )
229
+
230
+ # Write back dQ.
231
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
232
+ dq *= SM_SCALE
233
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
234
+ tl.store(dq_ptrs, dq)
235
+ else:
236
+ tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
237
+ else:
238
+ # THIS BLOCK DOES DK & DV
239
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
240
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
241
+
242
+ pid_mask = pid // SPARSE_KV_MULTIPLE
243
+
244
+ stride_q_num_blks_h = ks5
245
+ stride_q_idx_h = ks6*ks7
246
+ stride_q_idx_n = ks6
247
+
248
+
249
+ dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
250
+ dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
251
+
252
+ start_n1 = pid * BLOCK_N1
253
+ offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
254
+
255
+ # load K and V: they stay in SRAM throughout the inner loop.
256
+ k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
257
+ v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
258
+
259
+ if PRESCALE_QK:
260
+ k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
261
+
262
+ for off_g in range(0, GQA_SHARED_HEADS):
263
+ off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
264
+
265
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
266
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
267
+ do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
268
+ dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
269
+ off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
270
+
271
+ Q1 = Q + q_adj1
272
+ DO1 = DO + do_adj1
273
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
274
+ # if Q is broadcasted)
275
+ LSE1 = LSE + off_chz1
276
+ DELTA1 = DELTA + off_chz1
277
+
278
+ sparse_idx_hq1 = off_hq1 % SPARSE_HQ
279
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
280
+
281
+ sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
282
+ sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
283
+
284
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
285
+ # Q_IDX and Q_NUM_BLKS are always contiguous.
286
+ q_indices = Q_IDX + sparse_q_idx_offset
287
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
288
+ sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
289
+
290
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
291
+ dk, dv = bwd_dkdv_inner(
292
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7,
293
+ Q1, DO1, DELTA1, LSE1,
294
+ dk, dv, k, v,
295
+ off_zq, off_hq1, offs_n1, offs_m1,
296
+ stride_qm, stride_qd, stride_dom, stride_dod,
297
+ q_indices, sparse_q_num_blocks,
298
+ MATMUL_PRECISION,
299
+ IS_FULL_BLOCKS=False,
300
+ )
301
+
302
+
303
+ if HAS_FULL_BLOCKS:
304
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
305
+ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
306
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset
307
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
308
+ sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
309
+
310
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
311
+ dk, dv = bwd_dkdv_inner(
312
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7,
313
+ Q1, DO1, DELTA1, LSE1,
314
+ dk, dv, k, v,
315
+ off_zq, off_hq1, offs_n1, offs_m1,
316
+ stride_qm, stride_qd, stride_dom, stride_dod,
317
+ q_indices, sparse_q_num_blocks,
318
+ MATMUL_PRECISION,
319
+ IS_FULL_BLOCKS=True,
320
+ )
321
+
322
+ # Write back dV and dK.
323
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
324
+
325
+ index_n = offs_n1[:, None]
326
+ index_k = offs_k[None, :]
327
+ index_v = offs_v[None, :]
328
+
329
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
330
+ tl.store(dv_ptrs, dv)
331
+ else:
332
+ tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
333
+
334
+ dk *= SM_SCALE
335
+
336
+ if SAFE_HEAD_DIM:
337
+ mask = index_n < KV_LEN
338
+ else:
339
+ mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
340
+
341
+ # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
342
+ # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
343
+ tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
344
+ xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1
345
+ tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask)
346
+
347
+ @triton.jit
348
+ def bwd_dq_inner(
349
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7,
350
+ K, V, # pointers
351
+ dq, q, do, Di, lse,
352
+ off_z, off_hq, offs_m2, offs_n2,
353
+ stride_kn, stride_kd, stride_vn, stride_vd,
354
+ kv_indices, sparse_kv_num_blocks,
355
+ MATMUL_PRECISION,
356
+ IS_FULL_BLOCKS,
357
+ ):
358
+ PRESCALE_QK : tl.constexpr = False
359
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
360
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
361
+ WRITE_DQ : tl.constexpr = True
362
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
363
+ OUTPUT_MAX : tl.constexpr = False
364
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
365
+ IS_DIVISIBLE : tl.constexpr = False
366
+ SM_SCALE : tl.constexpr = 0.08838834764831845
367
+ GQA_SHARED_HEADS : tl.constexpr = 4
368
+ HAS_FULL_BLOCKS : tl.constexpr = True
369
+ QK_HEAD_DIM : tl.constexpr = 128
370
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
371
+ V_HEAD_DIM : tl.constexpr = 128
372
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
373
+ SAFE_HEAD_DIM : tl.constexpr = True
374
+ BLOCK_M1 : tl.constexpr = 64
375
+ BLOCK_N1 : tl.constexpr = 128
376
+ BLOCK_M2 : tl.constexpr = 128
377
+ BLOCK_N2 : tl.constexpr = 64
378
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
379
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
380
+ INDEX_DTYPE : tl.constexpr = tl.int32
381
+
382
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
383
+ RCP_LN2: tl.constexpr = 1.44269504
384
+ Q_LEN = ks0
385
+ KV_LEN = ks1
386
+
387
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
388
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
389
+
390
+ kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
391
+ vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
392
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
393
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
394
+
395
+ hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
396
+
397
+ for start_n in range(0, hi):
398
+ dq = bwd_dq_block_mn(
399
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7,
400
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
401
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
402
+ stride_kn, stride_kd, stride_vn, stride_vd,
403
+ kv_indices, sparse_kv_num_blocks,
404
+ MATMUL_PRECISION, RCP_LN2,
405
+ IS_FULL_BLOCKS,
406
+ )
407
+
408
+ # Increment pointers.
409
+ offset = get_offset_for_next_block(
410
+ start_n, kv_indices, sparse_kv_num_blocks,
411
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
412
+ )
413
+
414
+ kT_ptrs += offset * stride_kn
415
+ vT_ptrs += offset * stride_vn
416
+
417
+ offs_n2 += offset
418
+
419
+ return dq
420
+
421
+
422
+ @triton.jit
423
+ def bwd_dq_block_mn(
424
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7,
425
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
426
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
427
+ stride_kn, stride_kd, stride_vn, stride_vd,
428
+ kv_indices, sparse_kv_num_blocks,
429
+ MATMUL_PRECISION, RCP_LN2,
430
+ IS_FULL_BLOCKS,
431
+ ):
432
+ PRESCALE_QK : tl.constexpr = False
433
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
434
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
435
+ WRITE_DQ : tl.constexpr = True
436
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
437
+ OUTPUT_MAX : tl.constexpr = False
438
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
439
+ IS_DIVISIBLE : tl.constexpr = False
440
+ SM_SCALE : tl.constexpr = 0.08838834764831845
441
+ GQA_SHARED_HEADS : tl.constexpr = 4
442
+ HAS_FULL_BLOCKS : tl.constexpr = True
443
+ QK_HEAD_DIM : tl.constexpr = 128
444
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
445
+ V_HEAD_DIM : tl.constexpr = 128
446
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
447
+ SAFE_HEAD_DIM : tl.constexpr = True
448
+ BLOCK_M1 : tl.constexpr = 64
449
+ BLOCK_N1 : tl.constexpr = 128
450
+ BLOCK_M2 : tl.constexpr = 128
451
+ BLOCK_N2 : tl.constexpr = 64
452
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
453
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
454
+ INDEX_DTYPE : tl.constexpr = tl.int32
455
+
456
+
457
+ # NB reversed order to since K is transposed
458
+ kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
459
+ qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
460
+ if not PRESCALE_QK:
461
+ qk *= SM_SCALE
462
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
463
+ pre_mod_scores = qk
464
+ n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
465
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim
466
+ # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary
467
+ m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
468
+
469
+ tmp0 = (qk)
470
+ post_mod_scores = tmp0
471
+
472
+
473
+
474
+
475
+ if not IS_DIVISIBLE:
476
+ post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
477
+
478
+ if not IS_FULL_BLOCKS:
479
+ tmp1 = (m)
480
+ tmp2 = tl.full([1], 0, tl.int32)
481
+ tmp3 = tmp1 < tmp2
482
+ tmp4 = (n)
483
+ tmp5 = tmp4 <= tmp1
484
+ tmp6 = tmp3 & tmp5
485
+ tmp7 = tmp1 >= tmp2
486
+ tmp8 = tmp4 < tmp2
487
+ tmp9 = tmp7 & tmp8
488
+ tmp10 = tmp8 == 0
489
+ tmp11 = tmp7 & tmp10
490
+ tmp12 = tmp1 - tmp2
491
+ tmp13 = tl.full([1], 16, tl.int32)
492
+ tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
493
+ tmp15 = tmp4 - tmp2
494
+ tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
495
+ tmp17 = tmp14 == tmp16
496
+ tmp18 = tmp11 & tmp17
497
+ tmp19 = tmp9 | tmp18
498
+ tmp20 = tmp6 | tmp19
499
+ mask_mod_output = tmp20
500
+
501
+
502
+ # apply mask for partial masked block
503
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
504
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
505
+ if not PRESCALE_QK:
506
+ post_mod_scores *= RCP_LN2
507
+ p = tl.math.exp2(post_mod_scores - lse)
508
+ # Compute dP and dS.
509
+ # NB reversed order to since V is transposed
510
+ vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
511
+
512
+ dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
513
+ ds = p * (dp - Di[:, None])
514
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
515
+ tmp21 = (ds)
516
+ grad_scores = tmp21
517
+
518
+
519
+ if not IS_DIVISIBLE:
520
+ grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
521
+
522
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
523
+ if WRITE_DQ:
524
+ scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
525
+
526
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
527
+ ds = grad_scores
528
+
529
+ if not IS_FULL_BLOCKS:
530
+ # (grads) apply mask for partially unmasked block
531
+ ds = tl.where(mask_mod_output, ds, 0.0)
532
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
533
+ ds = ds.to(MATMUL_PRECISION)
534
+ # Compute dQ.
535
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
536
+
537
+ return dq
538
+
539
+
540
+ @triton.jit
541
+ def bwd_dkdv_inner(
542
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7,
543
+ Q, DO, DELTA, LSE, # pointers
544
+ dk, dv, k, v,
545
+ off_z, off_hq, offs_n1, offs_m1,
546
+ stride_qm, stride_qd, stride_dom, stride_dod,
547
+ q_indices, sparse_q_num_blocks,
548
+ MATMUL_PRECISION,
549
+ IS_FULL_BLOCKS,
550
+ ):
551
+ PRESCALE_QK : tl.constexpr = False
552
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
553
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
554
+ WRITE_DQ : tl.constexpr = True
555
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
556
+ OUTPUT_MAX : tl.constexpr = False
557
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
558
+ IS_DIVISIBLE : tl.constexpr = False
559
+ SM_SCALE : tl.constexpr = 0.08838834764831845
560
+ GQA_SHARED_HEADS : tl.constexpr = 4
561
+ HAS_FULL_BLOCKS : tl.constexpr = True
562
+ QK_HEAD_DIM : tl.constexpr = 128
563
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
564
+ V_HEAD_DIM : tl.constexpr = 128
565
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
566
+ SAFE_HEAD_DIM : tl.constexpr = True
567
+ BLOCK_M1 : tl.constexpr = 64
568
+ BLOCK_N1 : tl.constexpr = 128
569
+ BLOCK_M2 : tl.constexpr = 128
570
+ BLOCK_N2 : tl.constexpr = 64
571
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
572
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
573
+ INDEX_DTYPE : tl.constexpr = tl.int32
574
+
575
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
576
+ RCP_LN2: tl.constexpr = 1.44269504
577
+ Q_LEN = ks0
578
+ KV_LEN = ks1
579
+
580
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
581
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
582
+
583
+ qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
584
+ do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
585
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
586
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
587
+
588
+ # The minimum is needed to handle the case where we run with a super large
589
+ # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
590
+ hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
591
+
592
+ for start_m in range(0, hi):
593
+ dk, dv = bwd_dkdv_block_mn(
594
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7,
595
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
596
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
597
+ stride_qm, stride_qd, stride_dom, stride_dod,
598
+ q_indices, sparse_q_num_blocks,
599
+ MATMUL_PRECISION, RCP_LN2,
600
+ IS_FULL_BLOCKS,
601
+ )
602
+ # Increment pointers.
603
+ offset = get_offset_for_next_block(
604
+ start_m, q_indices, sparse_q_num_blocks,
605
+ SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
606
+ )
607
+
608
+ qT_ptrs += offset * stride_qm
609
+ do_ptrs += offset * stride_dom
610
+ offs_m1 += offset
611
+
612
+ return dk, dv
613
+
614
+
615
+ @triton.jit
616
+ def bwd_dkdv_block_mn(
617
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7,
618
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
619
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
620
+ stride_qm, stride_qd, stride_dom, stride_dod,
621
+ q_indices, sparse_q_num_blocks,
622
+ MATMUL_PRECISION, RCP_LN2,
623
+ IS_FULL_BLOCKS,
624
+ ):
625
+ PRESCALE_QK : tl.constexpr = False
626
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
627
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
628
+ WRITE_DQ : tl.constexpr = True
629
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
630
+ OUTPUT_MAX : tl.constexpr = False
631
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
632
+ IS_DIVISIBLE : tl.constexpr = False
633
+ SM_SCALE : tl.constexpr = 0.08838834764831845
634
+ GQA_SHARED_HEADS : tl.constexpr = 4
635
+ HAS_FULL_BLOCKS : tl.constexpr = True
636
+ QK_HEAD_DIM : tl.constexpr = 128
637
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
638
+ V_HEAD_DIM : tl.constexpr = 128
639
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
640
+ SAFE_HEAD_DIM : tl.constexpr = True
641
+ BLOCK_M1 : tl.constexpr = 64
642
+ BLOCK_N1 : tl.constexpr = 128
643
+ BLOCK_M2 : tl.constexpr = 128
644
+ BLOCK_N2 : tl.constexpr = 64
645
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
646
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
647
+ INDEX_DTYPE : tl.constexpr = tl.int32
648
+
649
+
650
+ # NB reversed order since Q is transposed
651
+ qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
652
+ # Load LSE before computing qk to reduce pipeline stall.
653
+ if IS_DIVISIBLE:
654
+ lse = tl.load(LSE + offs_m1)
655
+ else:
656
+ lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
657
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
658
+ qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
659
+ if not PRESCALE_QK:
660
+ qkT *= SM_SCALE
661
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
662
+ m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
663
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim
664
+ # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary
665
+ n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
666
+
667
+ pre_mod_scores = qkT
668
+ tmp22 = (qkT)
669
+ post_mod_scores = tmp22
670
+
671
+
672
+
673
+ if not IS_DIVISIBLE:
674
+ post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
675
+
676
+ if not IS_FULL_BLOCKS:
677
+ tmp23 = (m)
678
+ tmp24 = tl.full([1], 0, tl.int32)
679
+ tmp25 = tmp23 < tmp24
680
+ tmp26 = (n)
681
+ tmp27 = tmp26 <= tmp23
682
+ tmp28 = tmp25 & tmp27
683
+ tmp29 = tmp23 >= tmp24
684
+ tmp30 = tmp26 < tmp24
685
+ tmp31 = tmp29 & tmp30
686
+ tmp32 = tmp30 == 0
687
+ tmp33 = tmp29 & tmp32
688
+ tmp34 = tmp23 - tmp24
689
+ tmp35 = tl.full([1], 16, tl.int32)
690
+ tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35)
691
+ tmp37 = tmp26 - tmp24
692
+ tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35)
693
+ tmp39 = tmp36 == tmp38
694
+ tmp40 = tmp33 & tmp39
695
+ tmp41 = tmp31 | tmp40
696
+ tmp42 = tmp28 | tmp41
697
+ mask_mod_output = tmp42
698
+
699
+ # (grads) apply mask for fully masked block
700
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
701
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
702
+ if not PRESCALE_QK:
703
+ post_mod_scores *= RCP_LN2
704
+ pT = tl.math.exp2(post_mod_scores - lse[None, :])
705
+ do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
706
+ # Compute dV.
707
+ ppT = pT
708
+ dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
709
+ if IS_DIVISIBLE:
710
+ Di = tl.load(DELTA + offs_m1)
711
+ else:
712
+ Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
713
+ # Compute dP and dS.
714
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
715
+ dsT = pT * (dpT - Di[None, :])
716
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
717
+ tmp43 = (dsT)
718
+ grad_scores = tmp43
719
+
720
+
721
+
722
+ if not IS_DIVISIBLE:
723
+ grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
724
+
725
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
726
+ if not WRITE_DQ:
727
+ idx_b = off_z
728
+ idx_h = off_hq
729
+ idx_m = m
730
+ idx_n = n
731
+ scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
732
+
733
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
734
+ dsT = grad_scores
735
+ if not IS_FULL_BLOCKS:
736
+ # (grads) apply mask for partially unmasked block
737
+ dsT = tl.where(mask_mod_output, dsT, 0.0)
738
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
739
+ dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
740
+
741
+ return dk, dv
742
+
743
+ # Utility triton funcs
744
+ @triton.jit
745
+ def get_offset_for_next_block(
746
+ loop_iter, col_indices, total_blocks,
747
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
748
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
749
+ ):
750
+ if BLOCKS_ARE_CONTIGUOUS:
751
+ return BLOCK
752
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
753
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
754
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
755
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
756
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
757
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
758
+ return offset
759
+
760
+ @triton.jit
761
+ def get_bounded_indices(indices, max_len=None):
762
+ return indices % max_len if max_len is not None else indices
763
+
764
+ @triton.jit
765
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
766
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
767
+ return tl.load(block_ptr)
768
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
769
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
770
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
771
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
772
+ else:
773
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
774
+
775
+ @triton.jit
776
+ def load_checked_2d(
777
+ ptr,
778
+ offs_m,
779
+ offs_n,
780
+ stride_m,
781
+ stride_n,
782
+ IS_DIVISIBLE_M: tl.constexpr,
783
+ IS_DIVISIBLE_N: tl.constexpr,
784
+ M_LEN: tl.constexpr,
785
+ N_LEN: tl.constexpr,
786
+ ):
787
+ # Calculate final pointer if strides are provided
788
+ if stride_m is not None and stride_n is not None:
789
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
790
+
791
+ # Handle all masking cases
792
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
793
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
794
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
795
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
796
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
797
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
798
+ else: # Both divisible
799
+ return tl.load(ptr)
progress/SpecForge/cache/compiled_kernels/2v/c2vob47d7sxpitzmofyr55f5hvxsitxjhpyv5hdiqcdjgbwmxk76.py ADDED
@@ -0,0 +1,799 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = False
27
+ SM_SCALE : tl.constexpr = 0.08838834764831845
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ BLOCK_M1 : tl.constexpr = 64
36
+ BLOCK_N1 : tl.constexpr = 128
37
+ BLOCK_M2 : tl.constexpr = 128
38
+ BLOCK_N2 : tl.constexpr = 64
39
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
40
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
41
+ INDEX_DTYPE : tl.constexpr = tl.int32
42
+ Q = arg_Q
43
+ K = arg_K
44
+ V = arg_V
45
+ LSE = arg_LSE
46
+ DELTA = arg_DELTA
47
+ DO = arg_DO
48
+ DQ = arg_DQ
49
+ DV = arg_DV
50
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
51
+ KV_IDX = arg_KV_IDX
52
+ Q_NUM_BLKS = arg_Q_NUM_BLKS
53
+ Q_IDX = arg_Q_IDX
54
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
55
+ FULL_KV_IDX = arg_FULL_KV_IDX
56
+ FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
57
+ FULL_Q_IDX = arg_FULL_Q_IDX
58
+
59
+ # Sub notation for this kernel:
60
+ #
61
+ # Q: Query, K: Key, V: Value
62
+ # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
63
+ # DELTA: Precomputed sum(OUT*DO, axis=-1)
64
+ # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
65
+ # DK: Derivative of Key, is the written to via the store_output call due to some limitations with
66
+ # inductor codegen
67
+ # M: Number of queries, N: Number of keys/values
68
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
69
+ # V_HEAD_DIM: The dimension of the value embeddings
70
+ # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
71
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
72
+ # (Modifiable) Performance tuning options
73
+ # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
74
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
75
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
76
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
77
+ #
78
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
79
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
80
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
81
+ # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
82
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
83
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
84
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
85
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
86
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
87
+
88
+ # The below are kernel options that can be applied for certain score_mods,
89
+ # or involve a numerics vs. perf tradeoff
90
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
91
+ # about 20% more numerical error, but slightly faster.
92
+
93
+ # Define strides of inputs
94
+ stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1
95
+ stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1
96
+ stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1
97
+ stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1
98
+
99
+ stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1
100
+ stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1
101
+
102
+ ZQ = 1
103
+ HQ = 32
104
+ HKV = 8
105
+ Q_LEN = ks0
106
+ ZKV = 1
107
+ KV_LEN = ks1
108
+
109
+ MATMUL_PRECISION = Q.dtype.element_ty
110
+
111
+ pid = tl.program_id(0).to(INDEX_DTYPE)
112
+ NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
113
+ NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
114
+
115
+ off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
116
+ off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
117
+ off_zkv = off_zq % ZKV # kv batch idx
118
+
119
+ SPARSE_Z = 1
120
+ SPARSE_HQ = 1
121
+
122
+ sparse_idx_z = off_zq % SPARSE_Z
123
+
124
+ k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
125
+ v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
126
+ # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
127
+ # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
128
+ dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
129
+
130
+ # offset K, V, DV pointers for batch/kv-head
131
+ K += k_adj
132
+ V += v_adj
133
+ DV += dv_adj
134
+
135
+ RCP_LN2 = 1.44269504
136
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
137
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
138
+
139
+ if pid >= NUM_KV_BLOCKS:
140
+ off_pid = pid - NUM_KV_BLOCKS
141
+ # THIS BLOCK DOES DQ
142
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
143
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
144
+ off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
145
+ start_m2_block = off_pid % NUM_Q_BLOCKS
146
+ off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
147
+ stride_kv_num_blks_h = 1
148
+ stride_kv_idx_h = 1
149
+ stride_kv_idx_m = 1
150
+
151
+ sparse_idx_hq2 = off_hq2 % SPARSE_HQ
152
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
153
+
154
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
155
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
156
+
157
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
158
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
159
+ do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
160
+ dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
161
+ off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
162
+
163
+ Q2 = Q + q_adj2
164
+ DO2 = DO + do_adj2
165
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
166
+ # if Q is broadcasted)
167
+ DQ2 = DQ + dq_adj2
168
+ LSE2 = LSE + off_chz2
169
+ DELTA2 = DELTA + off_chz2
170
+
171
+ # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
172
+ dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
173
+
174
+ start_m2 = start_m2_block * BLOCK_M2
175
+ offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
176
+
177
+ # load Q and do: they stay in SRAM throughout the inner loop.
178
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
179
+ do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
180
+
181
+ if PRESCALE_QK:
182
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
183
+
184
+ if IS_DIVISIBLE:
185
+ Di = tl.load(DELTA2 + offs_m2)
186
+ lse = tl.load(LSE2 + offs_m2)
187
+ else:
188
+ Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
189
+ lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
190
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
191
+ lse = lse[:, None]
192
+
193
+ # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
194
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
195
+ kv_indices = KV_IDX + sparse_kv_idx_offset
196
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
197
+ sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
198
+
199
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
200
+ dq = bwd_dq_inner(
201
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
202
+ K, V,
203
+ dq, q, do, Di, lse,
204
+ off_zq, off_hq2, offs_m2, offs_n2,
205
+ stride_kn, stride_kd, stride_vn, stride_vd,
206
+ kv_indices, sparse_kv_num_blocks,
207
+ MATMUL_PRECISION,
208
+ IS_FULL_BLOCKS=False,
209
+ )
210
+
211
+ if HAS_FULL_BLOCKS:
212
+ # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
213
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
214
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
215
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
216
+ sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
217
+
218
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
219
+ dq = bwd_dq_inner(
220
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
221
+ K, V,
222
+ dq, q, do, Di, lse,
223
+ off_zq, off_hq2, offs_m2, offs_n2,
224
+ stride_kn, stride_kd, stride_vn, stride_vd,
225
+ kv_indices, sparse_kv_num_blocks,
226
+ MATMUL_PRECISION,
227
+ IS_FULL_BLOCKS=True,
228
+ )
229
+
230
+ # Write back dQ.
231
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
232
+ dq *= SM_SCALE
233
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
234
+ tl.store(dq_ptrs, dq)
235
+ else:
236
+ tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
237
+ else:
238
+ # THIS BLOCK DOES DK & DV
239
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
240
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
241
+
242
+ pid_mask = pid // SPARSE_KV_MULTIPLE
243
+
244
+ stride_q_num_blks_h = 1
245
+ stride_q_idx_h = 1
246
+ stride_q_idx_n = 1
247
+
248
+
249
+ dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
250
+ dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
251
+
252
+ start_n1 = pid * BLOCK_N1
253
+ offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
254
+
255
+ # load K and V: they stay in SRAM throughout the inner loop.
256
+ k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
257
+ v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
258
+
259
+ if PRESCALE_QK:
260
+ k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
261
+
262
+ for off_g in range(0, GQA_SHARED_HEADS):
263
+ off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
264
+
265
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
266
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
267
+ do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
268
+ dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
269
+ off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
270
+
271
+ Q1 = Q + q_adj1
272
+ DO1 = DO + do_adj1
273
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
274
+ # if Q is broadcasted)
275
+ LSE1 = LSE + off_chz1
276
+ DELTA1 = DELTA + off_chz1
277
+
278
+ sparse_idx_hq1 = off_hq1 % SPARSE_HQ
279
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
280
+
281
+ sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
282
+ sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
283
+
284
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
285
+ # Q_IDX and Q_NUM_BLKS are always contiguous.
286
+ q_indices = Q_IDX + sparse_q_idx_offset
287
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
288
+ sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
289
+
290
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
291
+ dk, dv = bwd_dkdv_inner(
292
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
293
+ Q1, DO1, DELTA1, LSE1,
294
+ dk, dv, k, v,
295
+ off_zq, off_hq1, offs_n1, offs_m1,
296
+ stride_qm, stride_qd, stride_dom, stride_dod,
297
+ q_indices, sparse_q_num_blocks,
298
+ MATMUL_PRECISION,
299
+ IS_FULL_BLOCKS=False,
300
+ )
301
+
302
+
303
+ if HAS_FULL_BLOCKS:
304
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
305
+ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
306
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset
307
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
308
+ sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
309
+
310
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
311
+ dk, dv = bwd_dkdv_inner(
312
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
313
+ Q1, DO1, DELTA1, LSE1,
314
+ dk, dv, k, v,
315
+ off_zq, off_hq1, offs_n1, offs_m1,
316
+ stride_qm, stride_qd, stride_dom, stride_dod,
317
+ q_indices, sparse_q_num_blocks,
318
+ MATMUL_PRECISION,
319
+ IS_FULL_BLOCKS=True,
320
+ )
321
+
322
+ # Write back dV and dK.
323
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
324
+
325
+ index_n = offs_n1[:, None]
326
+ index_k = offs_k[None, :]
327
+ index_v = offs_v[None, :]
328
+
329
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
330
+ tl.store(dv_ptrs, dv)
331
+ else:
332
+ tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
333
+
334
+ dk *= SM_SCALE
335
+
336
+ if SAFE_HEAD_DIM:
337
+ mask = index_n < KV_LEN
338
+ else:
339
+ mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
340
+
341
+ # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
342
+ # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
343
+ tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
344
+ xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1
345
+ tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask)
346
+
347
+ @triton.jit
348
+ def bwd_dq_inner(
349
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
350
+ K, V, # pointers
351
+ dq, q, do, Di, lse,
352
+ off_z, off_hq, offs_m2, offs_n2,
353
+ stride_kn, stride_kd, stride_vn, stride_vd,
354
+ kv_indices, sparse_kv_num_blocks,
355
+ MATMUL_PRECISION,
356
+ IS_FULL_BLOCKS,
357
+ ):
358
+ PRESCALE_QK : tl.constexpr = False
359
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
360
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
361
+ WRITE_DQ : tl.constexpr = True
362
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
363
+ OUTPUT_MAX : tl.constexpr = False
364
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
365
+ IS_DIVISIBLE : tl.constexpr = False
366
+ SM_SCALE : tl.constexpr = 0.08838834764831845
367
+ GQA_SHARED_HEADS : tl.constexpr = 4
368
+ HAS_FULL_BLOCKS : tl.constexpr = True
369
+ QK_HEAD_DIM : tl.constexpr = 128
370
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
371
+ V_HEAD_DIM : tl.constexpr = 128
372
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
373
+ SAFE_HEAD_DIM : tl.constexpr = True
374
+ BLOCK_M1 : tl.constexpr = 64
375
+ BLOCK_N1 : tl.constexpr = 128
376
+ BLOCK_M2 : tl.constexpr = 128
377
+ BLOCK_N2 : tl.constexpr = 64
378
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
379
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
380
+ INDEX_DTYPE : tl.constexpr = tl.int32
381
+
382
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
383
+ RCP_LN2: tl.constexpr = 1.44269504
384
+ Q_LEN = ks0
385
+ KV_LEN = ks1
386
+
387
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
388
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
389
+
390
+ kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
391
+ vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
392
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
393
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
394
+
395
+ hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
396
+
397
+ for start_n in range(0, hi):
398
+ dq = bwd_dq_block_mn(
399
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
400
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
401
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
402
+ stride_kn, stride_kd, stride_vn, stride_vd,
403
+ kv_indices, sparse_kv_num_blocks,
404
+ MATMUL_PRECISION, RCP_LN2,
405
+ IS_FULL_BLOCKS,
406
+ )
407
+
408
+ # Increment pointers.
409
+ offset = get_offset_for_next_block(
410
+ start_n, kv_indices, sparse_kv_num_blocks,
411
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
412
+ )
413
+
414
+ kT_ptrs += offset * stride_kn
415
+ vT_ptrs += offset * stride_vn
416
+
417
+ offs_n2 += offset
418
+
419
+ return dq
420
+
421
+
422
+ @triton.jit
423
+ def bwd_dq_block_mn(
424
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
425
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
426
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
427
+ stride_kn, stride_kd, stride_vn, stride_vd,
428
+ kv_indices, sparse_kv_num_blocks,
429
+ MATMUL_PRECISION, RCP_LN2,
430
+ IS_FULL_BLOCKS,
431
+ ):
432
+ PRESCALE_QK : tl.constexpr = False
433
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
434
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
435
+ WRITE_DQ : tl.constexpr = True
436
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
437
+ OUTPUT_MAX : tl.constexpr = False
438
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
439
+ IS_DIVISIBLE : tl.constexpr = False
440
+ SM_SCALE : tl.constexpr = 0.08838834764831845
441
+ GQA_SHARED_HEADS : tl.constexpr = 4
442
+ HAS_FULL_BLOCKS : tl.constexpr = True
443
+ QK_HEAD_DIM : tl.constexpr = 128
444
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
445
+ V_HEAD_DIM : tl.constexpr = 128
446
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
447
+ SAFE_HEAD_DIM : tl.constexpr = True
448
+ BLOCK_M1 : tl.constexpr = 64
449
+ BLOCK_N1 : tl.constexpr = 128
450
+ BLOCK_M2 : tl.constexpr = 128
451
+ BLOCK_N2 : tl.constexpr = 64
452
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
453
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
454
+ INDEX_DTYPE : tl.constexpr = tl.int32
455
+
456
+
457
+ # NB reversed order to since K is transposed
458
+ kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
459
+ qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
460
+ if not PRESCALE_QK:
461
+ qk *= SM_SCALE
462
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
463
+ pre_mod_scores = qk
464
+ n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
465
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim
466
+ # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary
467
+ m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
468
+
469
+ tmp0 = (qk)
470
+ post_mod_scores = tmp0
471
+
472
+
473
+
474
+
475
+ if not IS_DIVISIBLE:
476
+ post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
477
+
478
+ if not IS_FULL_BLOCKS:
479
+ tmp1 = (m)
480
+ tmp2 = tl.full([1], 0, tl.int32)
481
+ tmp3 = tmp1 < tmp2
482
+ tmp4 = (n)
483
+ tmp5 = tmp4 <= tmp1
484
+ tmp6 = tmp3 & tmp5
485
+ tmp7 = tmp1 >= tmp2
486
+ tmp8 = tmp4 < tmp2
487
+ tmp9 = tmp7 & tmp8
488
+ tmp10 = tmp8 == 0
489
+ tmp11 = tmp7 & tmp10
490
+ tmp12 = tmp1 - tmp2
491
+ tmp13 = tl.full([1], 16, tl.int32)
492
+ tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
493
+ tmp15 = tmp4 - tmp2
494
+ tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
495
+ tmp17 = tmp14 == tmp16
496
+ tmp18 = tmp11 & tmp17
497
+ tmp19 = tmp9 | tmp18
498
+ tmp20 = tmp6 | tmp19
499
+ mask_mod_output = tmp20
500
+
501
+
502
+ # apply mask for partial masked block
503
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
504
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
505
+ if not PRESCALE_QK:
506
+ post_mod_scores *= RCP_LN2
507
+ p = tl.math.exp2(post_mod_scores - lse)
508
+ # Compute dP and dS.
509
+ # NB reversed order to since V is transposed
510
+ vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
511
+
512
+ dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
513
+ ds = p * (dp - Di[:, None])
514
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
515
+ tmp21 = (ds)
516
+ grad_scores = tmp21
517
+
518
+
519
+ if not IS_DIVISIBLE:
520
+ grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
521
+
522
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
523
+ if WRITE_DQ:
524
+ scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
525
+
526
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
527
+ ds = grad_scores
528
+
529
+ if not IS_FULL_BLOCKS:
530
+ # (grads) apply mask for partially unmasked block
531
+ ds = tl.where(mask_mod_output, ds, 0.0)
532
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
533
+ ds = ds.to(MATMUL_PRECISION)
534
+ # Compute dQ.
535
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
536
+
537
+ return dq
538
+
539
+
540
+ @triton.jit
541
+ def bwd_dkdv_inner(
542
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
543
+ Q, DO, DELTA, LSE, # pointers
544
+ dk, dv, k, v,
545
+ off_z, off_hq, offs_n1, offs_m1,
546
+ stride_qm, stride_qd, stride_dom, stride_dod,
547
+ q_indices, sparse_q_num_blocks,
548
+ MATMUL_PRECISION,
549
+ IS_FULL_BLOCKS,
550
+ ):
551
+ PRESCALE_QK : tl.constexpr = False
552
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
553
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
554
+ WRITE_DQ : tl.constexpr = True
555
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
556
+ OUTPUT_MAX : tl.constexpr = False
557
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
558
+ IS_DIVISIBLE : tl.constexpr = False
559
+ SM_SCALE : tl.constexpr = 0.08838834764831845
560
+ GQA_SHARED_HEADS : tl.constexpr = 4
561
+ HAS_FULL_BLOCKS : tl.constexpr = True
562
+ QK_HEAD_DIM : tl.constexpr = 128
563
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
564
+ V_HEAD_DIM : tl.constexpr = 128
565
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
566
+ SAFE_HEAD_DIM : tl.constexpr = True
567
+ BLOCK_M1 : tl.constexpr = 64
568
+ BLOCK_N1 : tl.constexpr = 128
569
+ BLOCK_M2 : tl.constexpr = 128
570
+ BLOCK_N2 : tl.constexpr = 64
571
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
572
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
573
+ INDEX_DTYPE : tl.constexpr = tl.int32
574
+
575
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
576
+ RCP_LN2: tl.constexpr = 1.44269504
577
+ Q_LEN = ks0
578
+ KV_LEN = ks1
579
+
580
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
581
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
582
+
583
+ qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
584
+ do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
585
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
586
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
587
+
588
+ # The minimum is needed to handle the case where we run with a super large
589
+ # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
590
+ hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
591
+
592
+ for start_m in range(0, hi):
593
+ dk, dv = bwd_dkdv_block_mn(
594
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
595
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
596
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
597
+ stride_qm, stride_qd, stride_dom, stride_dod,
598
+ q_indices, sparse_q_num_blocks,
599
+ MATMUL_PRECISION, RCP_LN2,
600
+ IS_FULL_BLOCKS,
601
+ )
602
+ # Increment pointers.
603
+ offset = get_offset_for_next_block(
604
+ start_m, q_indices, sparse_q_num_blocks,
605
+ SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
606
+ )
607
+
608
+ qT_ptrs += offset * stride_qm
609
+ do_ptrs += offset * stride_dom
610
+ offs_m1 += offset
611
+
612
+ return dk, dv
613
+
614
+
615
+ @triton.jit
616
+ def bwd_dkdv_block_mn(
617
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
618
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
619
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
620
+ stride_qm, stride_qd, stride_dom, stride_dod,
621
+ q_indices, sparse_q_num_blocks,
622
+ MATMUL_PRECISION, RCP_LN2,
623
+ IS_FULL_BLOCKS,
624
+ ):
625
+ PRESCALE_QK : tl.constexpr = False
626
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
627
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
628
+ WRITE_DQ : tl.constexpr = True
629
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
630
+ OUTPUT_MAX : tl.constexpr = False
631
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
632
+ IS_DIVISIBLE : tl.constexpr = False
633
+ SM_SCALE : tl.constexpr = 0.08838834764831845
634
+ GQA_SHARED_HEADS : tl.constexpr = 4
635
+ HAS_FULL_BLOCKS : tl.constexpr = True
636
+ QK_HEAD_DIM : tl.constexpr = 128
637
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
638
+ V_HEAD_DIM : tl.constexpr = 128
639
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
640
+ SAFE_HEAD_DIM : tl.constexpr = True
641
+ BLOCK_M1 : tl.constexpr = 64
642
+ BLOCK_N1 : tl.constexpr = 128
643
+ BLOCK_M2 : tl.constexpr = 128
644
+ BLOCK_N2 : tl.constexpr = 64
645
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
646
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
647
+ INDEX_DTYPE : tl.constexpr = tl.int32
648
+
649
+
650
+ # NB reversed order since Q is transposed
651
+ qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
652
+ # Load LSE before computing qk to reduce pipeline stall.
653
+ if IS_DIVISIBLE:
654
+ lse = tl.load(LSE + offs_m1)
655
+ else:
656
+ lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
657
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
658
+ qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
659
+ if not PRESCALE_QK:
660
+ qkT *= SM_SCALE
661
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
662
+ m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
663
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim
664
+ # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary
665
+ n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
666
+
667
+ pre_mod_scores = qkT
668
+ tmp22 = (qkT)
669
+ post_mod_scores = tmp22
670
+
671
+
672
+
673
+ if not IS_DIVISIBLE:
674
+ post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
675
+
676
+ if not IS_FULL_BLOCKS:
677
+ tmp23 = (m)
678
+ tmp24 = tl.full([1], 0, tl.int32)
679
+ tmp25 = tmp23 < tmp24
680
+ tmp26 = (n)
681
+ tmp27 = tmp26 <= tmp23
682
+ tmp28 = tmp25 & tmp27
683
+ tmp29 = tmp23 >= tmp24
684
+ tmp30 = tmp26 < tmp24
685
+ tmp31 = tmp29 & tmp30
686
+ tmp32 = tmp30 == 0
687
+ tmp33 = tmp29 & tmp32
688
+ tmp34 = tmp23 - tmp24
689
+ tmp35 = tl.full([1], 16, tl.int32)
690
+ tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35)
691
+ tmp37 = tmp26 - tmp24
692
+ tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35)
693
+ tmp39 = tmp36 == tmp38
694
+ tmp40 = tmp33 & tmp39
695
+ tmp41 = tmp31 | tmp40
696
+ tmp42 = tmp28 | tmp41
697
+ mask_mod_output = tmp42
698
+
699
+ # (grads) apply mask for fully masked block
700
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
701
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
702
+ if not PRESCALE_QK:
703
+ post_mod_scores *= RCP_LN2
704
+ pT = tl.math.exp2(post_mod_scores - lse[None, :])
705
+ do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
706
+ # Compute dV.
707
+ ppT = pT
708
+ dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
709
+ if IS_DIVISIBLE:
710
+ Di = tl.load(DELTA + offs_m1)
711
+ else:
712
+ Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
713
+ # Compute dP and dS.
714
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
715
+ dsT = pT * (dpT - Di[None, :])
716
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
717
+ tmp43 = (dsT)
718
+ grad_scores = tmp43
719
+
720
+
721
+
722
+ if not IS_DIVISIBLE:
723
+ grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
724
+
725
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
726
+ if not WRITE_DQ:
727
+ idx_b = off_z
728
+ idx_h = off_hq
729
+ idx_m = m
730
+ idx_n = n
731
+ scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
732
+
733
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
734
+ dsT = grad_scores
735
+ if not IS_FULL_BLOCKS:
736
+ # (grads) apply mask for partially unmasked block
737
+ dsT = tl.where(mask_mod_output, dsT, 0.0)
738
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
739
+ dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
740
+
741
+ return dk, dv
742
+
743
+ # Utility triton funcs
744
+ @triton.jit
745
+ def get_offset_for_next_block(
746
+ loop_iter, col_indices, total_blocks,
747
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
748
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
749
+ ):
750
+ if BLOCKS_ARE_CONTIGUOUS:
751
+ return BLOCK
752
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
753
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
754
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
755
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
756
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
757
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
758
+ return offset
759
+
760
+ @triton.jit
761
+ def get_bounded_indices(indices, max_len=None):
762
+ return indices % max_len if max_len is not None else indices
763
+
764
+ @triton.jit
765
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
766
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
767
+ return tl.load(block_ptr)
768
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
769
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
770
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
771
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
772
+ else:
773
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
774
+
775
+ @triton.jit
776
+ def load_checked_2d(
777
+ ptr,
778
+ offs_m,
779
+ offs_n,
780
+ stride_m,
781
+ stride_n,
782
+ IS_DIVISIBLE_M: tl.constexpr,
783
+ IS_DIVISIBLE_N: tl.constexpr,
784
+ M_LEN: tl.constexpr,
785
+ N_LEN: tl.constexpr,
786
+ ):
787
+ # Calculate final pointer if strides are provided
788
+ if stride_m is not None and stride_n is not None:
789
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
790
+
791
+ # Handle all masking cases
792
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
793
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
794
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
795
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
796
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
797
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
798
+ else: # Both divisible
799
+ return tl.load(ptr)
progress/SpecForge/cache/compiled_kernels/2y/c2yhndikcsebqfmbw7l44gmcdoyw7ogaqt7quyeygz3mp5w6u6ke.py ADDED
@@ -0,0 +1,715 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['4_inference']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/6n/c6n4rf57opno6rcuedu4jk4etcok4ti2tlaztx2ht3z5eydc3vae.py
38
+ # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
39
+ # Source node to ATen node mapping:
40
+ # flex_attention => flex_attention
41
+ # Graph fragment:
42
+ # %arg1_1 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:1" = PlaceHolder[target=arg1_1]
43
+ # %arg3_1 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:1" = PlaceHolder[target=arg3_1]
44
+ # %arg5_1 : Tensor "bf16[1, 8, s43, 128][1024*s43, 128, 1024, 1]cuda:1" = PlaceHolder[target=arg5_1]
45
+ # %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:1" = PlaceHolder[target=getitem_1]
46
+ # %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:1" = PlaceHolder[target=buf1]
47
+ # %arg9_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:1" = PlaceHolder[target=arg9_1]
48
+ # %arg6_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:1" = PlaceHolder[target=arg6_1]
49
+ # %arg10_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:1" = PlaceHolder[target=arg10_1]
50
+ # %arg11_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:1" = PlaceHolder[target=arg11_1]
51
+ # %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
52
+ # return %getitem
53
+ triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', '''
54
+ import triton
55
+ import triton.language as tl
56
+
57
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
58
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
59
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
60
+
61
+ @triton_heuristics.template(
62
+
63
+ num_stages=3,
64
+ num_warps=8,
65
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]},
66
+ inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
67
+
68
+ )
69
+ @triton.jit
70
+ def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2):
71
+ PRESCALE_QK : tl.constexpr = False
72
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
73
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
74
+ WRITE_DQ : tl.constexpr = True
75
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
76
+ OUTPUT_MAX : tl.constexpr = False
77
+ FLOAT32_PRECISION : tl.constexpr = 'ieee'
78
+ IS_DIVISIBLE : tl.constexpr = False
79
+ SM_SCALE : tl.constexpr = 0.08838834764831845
80
+ GQA_SHARED_HEADS : tl.constexpr = 4
81
+ HAS_FULL_BLOCKS : tl.constexpr = True
82
+ QK_HEAD_DIM : tl.constexpr = 128
83
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
84
+ V_HEAD_DIM : tl.constexpr = 128
85
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
86
+ SAFE_HEAD_DIM : tl.constexpr = True
87
+ USE_TMA : tl.constexpr = False
88
+ BLOCK_M : tl.constexpr = 128
89
+ BLOCK_N : tl.constexpr = 64
90
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
91
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
92
+ INDEX_DTYPE : tl.constexpr = tl.int32
93
+ Q = arg_Q
94
+ K = arg_K
95
+ V = arg_V
96
+ LSE = arg_LSE
97
+ MAX = arg_MAX
98
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
99
+ KV_IDX = arg_KV_IDX
100
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
101
+ FULL_KV_IDX = arg_FULL_KV_IDX
102
+
103
+ # Sub notation for this kernel:
104
+ #
105
+ # Q: Query, K: Key, V: Value
106
+ # M: Number of queries, N: Number of keys/values, D: Model dimension
107
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
108
+ # V_HEAD_DIM: The dimension of the value embeddings
109
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
110
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
111
+ #
112
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
113
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
114
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
115
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
116
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
117
+ #
118
+ # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
119
+ #
120
+ # (Modifiable) Performance tuning options
121
+ # BLOCK_M: The thread block size across the seqlen dim of Q.
122
+ # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
123
+
124
+ # The below are kernel options that can be applied for certain score_mods,
125
+ # or involve a numerics vs. perf tradeoff
126
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
127
+ # about 20% more numerical error, but slightly faster.
128
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
129
+ # is not masked out? If so, we can skip an extra safety check
130
+ # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
131
+ # contiguous? If so, we don't need to do an indirect jump for every block
132
+
133
+ tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
134
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
135
+
136
+ # Define strides of inputs
137
+ stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1
138
+ stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1
139
+ stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1
140
+
141
+ ZQ = 1
142
+ HQ = 32
143
+ Q_LEN = ks0
144
+ ZKV = 1
145
+ KV_LEN = ks1
146
+
147
+ MATMUL_PRECISION = Q.dtype.element_ty
148
+
149
+ q_start = tl.program_id(0).to(INDEX_DTYPE)
150
+ off_zq = tl.program_id(1).to(INDEX_DTYPE)
151
+ off_hq = tl.program_id(2).to(INDEX_DTYPE)
152
+
153
+ # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
154
+ # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
155
+ off_zkv = off_zq % ZKV
156
+ off_hkv = off_hq // GQA_SHARED_HEADS
157
+ off_g = off_hq % GQA_SHARED_HEADS
158
+
159
+ q_offset = off_zq * stride_qz + off_hq * stride_qh
160
+ k_offset = off_zkv * stride_kz + off_hkv * stride_kh
161
+ v_offset = off_zkv * stride_vz + off_hkv * stride_vh
162
+
163
+ Q = Q + q_offset
164
+ K = K + k_offset
165
+ V = V + v_offset
166
+
167
+ # Setting up the TMA descriptors for Q, K, V
168
+ desc_q = None
169
+ desc_k = None
170
+ desc_v = None
171
+
172
+ SPARSE_Z = 1
173
+ SPARSE_HQ = 1
174
+
175
+ sparse_idx_z = off_zq % SPARSE_Z
176
+ sparse_idx_hq = off_hq % SPARSE_HQ
177
+
178
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
179
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
180
+
181
+ stride_kv_num_blks_h = 1
182
+ stride_kv_idx_h = 1
183
+ stride_kv_idx_m = 1
184
+
185
+ # initialize pointer to m and l
186
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
187
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
188
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
189
+
190
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
191
+
192
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
193
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
194
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
195
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
196
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
197
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
198
+ q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
199
+
200
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
201
+ # We don't know anything "special" about these blocks, so we need to apply
202
+ # both score_mod and mask_mod to it
203
+ kv_indices = KV_IDX + sparse_kv_idx_offset
204
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
205
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
206
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
207
+
208
+
209
+ # K and V pointers will be passed directly to forward_inner
210
+
211
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
212
+
213
+
214
+ acc, l_i, m_i = forward_inner(
215
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
216
+ q, K, V,
217
+ desc_k, desc_v, Q_LEN, KV_LEN,
218
+ acc, l_i, m_i,
219
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
220
+ kv_start,
221
+ kv_indices, kv_num_blocks,
222
+ 0, block_n_end,
223
+ MATMUL_PRECISION,
224
+ stride_kk, stride_kn, stride_vn, stride_vk,
225
+ IS_FULL_BLOCKS=False,
226
+ )
227
+
228
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
229
+ # We know these blocks are guaranteed to be "full", so we don't need to
230
+ # apply mask_mod to them - only score_mod
231
+ if HAS_FULL_BLOCKS:
232
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
233
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
234
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
235
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
236
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
237
+ # K and V pointers will be passed directly to forward_inner
238
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
239
+
240
+ acc, l_i, m_i = forward_inner(
241
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
242
+ q, K, V,
243
+ desc_k, desc_v, Q_LEN, KV_LEN,
244
+ acc, l_i, m_i,
245
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
246
+ kv_start,
247
+ kv_indices, kv_num_blocks,
248
+ 0, block_n_end,
249
+ MATMUL_PRECISION,
250
+ stride_kk, stride_kn, stride_vn, stride_vk,
251
+ IS_FULL_BLOCKS=True,
252
+ )
253
+
254
+
255
+ # [Note] Handle fully masked out rows:
256
+ # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
257
+ # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
258
+ l_i = tl.where(l_i == 0.0, 1, l_i)
259
+
260
+ acc = acc / l_i[:, None]
261
+ idx_zq = tl.program_id(1).to(INDEX_DTYPE)
262
+ idx_hq = tl.program_id(2).to(INDEX_DTYPE)
263
+ idx_m = offs_m[:, None].to(INDEX_DTYPE)
264
+ idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
265
+
266
+ mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
267
+
268
+ tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
269
+ xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0
270
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask)
271
+
272
+ if OUTPUT_LOGSUMEXP:
273
+ off_hz = off_zq * HQ + off_hq
274
+ l_ptrs = LSE + off_hz * Q_LEN + offs_m
275
+ lse = m_i + tl.math.log2(l_i)
276
+ if IS_DIVISIBLE:
277
+ tl.store(l_ptrs, lse)
278
+ else:
279
+ tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
280
+
281
+ if OUTPUT_MAX:
282
+ off_hz = off_zq * HQ + off_hq
283
+ max_ptrs = MAX + off_hz * Q_LEN + offs_m
284
+ if IS_DIVISIBLE:
285
+ tl.store(max_ptrs, m_i)
286
+ else:
287
+ tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
288
+
289
+
290
+ # Utility triton funcs
291
+ @triton.jit
292
+ def get_offset_for_next_block(
293
+ loop_iter, col_indices, total_blocks,
294
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
295
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
296
+ ):
297
+ if BLOCKS_ARE_CONTIGUOUS:
298
+ return BLOCK
299
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
300
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
301
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
302
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
303
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
304
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
305
+ return offset
306
+
307
+ @triton.jit
308
+ def get_bounded_indices(indices, max_len=None):
309
+ return indices % max_len if max_len is not None else indices
310
+
311
+ @triton.jit
312
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
313
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
314
+ return tl.load(block_ptr)
315
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
316
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
317
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
318
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
319
+ else:
320
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
321
+
322
+ @triton.jit
323
+ def load_checked_2d(
324
+ ptr,
325
+ offs_m,
326
+ offs_n,
327
+ stride_m,
328
+ stride_n,
329
+ IS_DIVISIBLE_M: tl.constexpr,
330
+ IS_DIVISIBLE_N: tl.constexpr,
331
+ M_LEN: tl.constexpr,
332
+ N_LEN: tl.constexpr,
333
+ ):
334
+ # Calculate final pointer if strides are provided
335
+ if stride_m is not None and stride_n is not None:
336
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
337
+
338
+ # Handle all masking cases
339
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
340
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
341
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
342
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
343
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
344
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
345
+ else: # Both divisible
346
+ return tl.load(ptr)
347
+
348
+
349
+ # Common Imports
350
+ @triton.jit
351
+ def forward_block_mn(
352
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
353
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
354
+ # accumulated values
355
+ acc, l_i, m_i,
356
+ # Offsets
357
+ off_z, off_h, offs_m, offs_n,
358
+ # Offsets needed for TMA loads
359
+ kv_start,
360
+ kv_offset,
361
+ MATMUL_PRECISION, RCP_LN2,
362
+ # Strides for K and V
363
+ stride_kk, stride_kn, stride_vn, stride_vk,
364
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
365
+
366
+ ):
367
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
368
+ PRESCALE_QK : tl.constexpr = False
369
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
370
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
371
+ WRITE_DQ : tl.constexpr = True
372
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
373
+ OUTPUT_MAX : tl.constexpr = False
374
+ FLOAT32_PRECISION : tl.constexpr = 'ieee'
375
+ IS_DIVISIBLE : tl.constexpr = False
376
+ SM_SCALE : tl.constexpr = 0.08838834764831845
377
+ GQA_SHARED_HEADS : tl.constexpr = 4
378
+ HAS_FULL_BLOCKS : tl.constexpr = True
379
+ QK_HEAD_DIM : tl.constexpr = 128
380
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
381
+ V_HEAD_DIM : tl.constexpr = 128
382
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
383
+ SAFE_HEAD_DIM : tl.constexpr = True
384
+ USE_TMA : tl.constexpr = False
385
+ BLOCK_M : tl.constexpr = 128
386
+ BLOCK_N : tl.constexpr = 64
387
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
388
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
389
+ INDEX_DTYPE : tl.constexpr = tl.int32
390
+
391
+
392
+ # -- load k --
393
+ # NB reversed order to since K is transposed
394
+ kv_base_offset = kv_start + kv_offset
395
+
396
+ # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
397
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
398
+ offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
399
+ k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
400
+
401
+ k = tl.trans(k)
402
+ # -- compute qk ---
403
+ qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
404
+ if not PRESCALE_QK:
405
+ qk *= SM_SCALE
406
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
407
+ # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
408
+ # which is larger than the actual number of elements. To avoid access memory out of bound,
409
+ # we need to mask out the elements that are out of Q_LEN & KV_LEN.
410
+ m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
411
+ n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
412
+
413
+ tmp0 = (qk)
414
+ post_mod_scores = tmp0
415
+
416
+
417
+ if CHECK_BLOCK_BOUNDARY:
418
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
419
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
420
+
421
+ if not IS_FULL_BLOCKS:
422
+ tmp1 = (m)
423
+ tmp2 = tl.full([1], 0, tl.int32)
424
+ tmp3 = tmp1 < tmp2
425
+ tmp4 = (n)
426
+ tmp5 = tmp4 <= tmp1
427
+ tmp6 = tmp3 & tmp5
428
+ tmp7 = tmp1 >= tmp2
429
+ tmp8 = tmp4 < tmp2
430
+ tmp9 = tmp7 & tmp8
431
+ tmp10 = tmp8 == 0
432
+ tmp11 = tmp7 & tmp10
433
+ tmp12 = tmp1 - tmp2
434
+ tmp13 = tl.full([1], 16, tl.int32)
435
+ tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
436
+ tmp15 = tmp4 - tmp2
437
+ tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
438
+ tmp17 = tmp14 == tmp16
439
+ tmp18 = tmp11 & tmp17
440
+ tmp19 = tmp9 | tmp18
441
+ tmp20 = tmp6 | tmp19
442
+ mask_mod_output = tmp20
443
+
444
+
445
+ if CHECK_BLOCK_BOUNDARY:
446
+ mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
447
+ # apply mask for partially unmasked blocks
448
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
449
+
450
+ if not PRESCALE_QK:
451
+ post_mod_scores *= RCP_LN2
452
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
453
+
454
+ # -- compute scaling constant ---
455
+ m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
456
+ if not ROWS_GUARANTEED_SAFE:
457
+ masked_out_rows = (m_ij == float("-inf"))
458
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
459
+ else:
460
+ m_ij_masked = m_ij
461
+
462
+ alpha = tl.math.exp2(m_i - m_ij_masked)
463
+ p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
464
+
465
+ # NB: l_i update is pulled up here since it's a bit faster
466
+ # NB: For headdim=256, it's faster to move it back down to after m_i =
467
+ # m_ij
468
+ l_i = l_i * alpha + tl.sum(p, 1)
469
+ # # -- scale and update acc --
470
+ acc = acc * alpha[:, None]
471
+ # Calculate offsets for V loading - reuse kv_base_offset from K loading
472
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
473
+ v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
474
+ acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
475
+
476
+ # -- update m_i
477
+ m_i = m_ij
478
+
479
+ return acc, l_i, m_i
480
+
481
+ @triton.jit
482
+ def forward_inner(
483
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
484
+ q, K, V,
485
+ desc_k, desc_v, Q_LEN, KV_LEN,
486
+ # accumulated values
487
+ acc, l_i, m_i,
488
+ # Offsets used as inputs to score_mod & mask_mod
489
+ # of size [BLOCK_M, BLOCK_N] or scalar.
490
+ off_z, off_h, offs_m, offs_n,
491
+ # Offsets needed for TMA loads
492
+ kv_start,
493
+ # blocksparse data
494
+ kv_indices, kv_num_blocks,
495
+ # start kv and end kv block
496
+ block_n_start, block_n_end,
497
+ MATMUL_PRECISION,
498
+ # Strides for K and V
499
+ stride_kk, stride_kn, stride_vn, stride_vk,
500
+ IS_FULL_BLOCKS,
501
+ ):
502
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
503
+ PRESCALE_QK : tl.constexpr = False
504
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
505
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
506
+ WRITE_DQ : tl.constexpr = True
507
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
508
+ OUTPUT_MAX : tl.constexpr = False
509
+ FLOAT32_PRECISION : tl.constexpr = 'ieee'
510
+ IS_DIVISIBLE : tl.constexpr = False
511
+ SM_SCALE : tl.constexpr = 0.08838834764831845
512
+ GQA_SHARED_HEADS : tl.constexpr = 4
513
+ HAS_FULL_BLOCKS : tl.constexpr = True
514
+ QK_HEAD_DIM : tl.constexpr = 128
515
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
516
+ V_HEAD_DIM : tl.constexpr = 128
517
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
518
+ SAFE_HEAD_DIM : tl.constexpr = True
519
+ USE_TMA : tl.constexpr = False
520
+ BLOCK_M : tl.constexpr = 128
521
+ BLOCK_N : tl.constexpr = 64
522
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
523
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
524
+ INDEX_DTYPE : tl.constexpr = tl.int32
525
+
526
+
527
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
528
+ RCP_LN2: tl.constexpr = 1.44269504
529
+
530
+ if PRESCALE_QK:
531
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
532
+
533
+ kv_offset = 0
534
+
535
+ # loop over k, v and update accumulator until block_n_end
536
+ for start_n in range(block_n_start, block_n_end):
537
+ # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
538
+ if IS_DIVISIBLE:
539
+ acc, l_i, m_i = forward_block_mn(
540
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
541
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
542
+ # accumulated values
543
+ acc, l_i, m_i,
544
+ # Offsets
545
+ off_z, off_h, offs_m, offs_n,
546
+ # Offsets needed for TMA loads
547
+ kv_start,
548
+ kv_offset,
549
+ MATMUL_PRECISION, RCP_LN2,
550
+ # Strides for K and V
551
+ stride_kk, stride_kn, stride_vn, stride_vk,
552
+ IS_FULL_BLOCKS,
553
+ )
554
+ else:
555
+ # Benchmark shows even we applied mod & mask to each block for non divisible seqlen,
556
+ # it's on par or slightly faster than only applying to the last block in fwd.
557
+ # However, we choose different strategy for bwd, where we only apply mod & mask
558
+ # to the last block because it's faster a lot.
559
+ acc, l_i, m_i = forward_block_mn(
560
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
561
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
562
+ # accumulated values
563
+ acc, l_i, m_i,
564
+ # Offsets
565
+ off_z, off_h, offs_m, offs_n,
566
+ # Offsets needed for TMA loads
567
+ kv_start,
568
+ kv_offset,
569
+ MATMUL_PRECISION, RCP_LN2,
570
+ # Strides for K and V
571
+ stride_kk, stride_kn, stride_vn, stride_vk,
572
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
573
+ )
574
+
575
+
576
+
577
+ offset = get_offset_for_next_block(
578
+ start_n, kv_indices, kv_num_blocks,
579
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
580
+ )
581
+
582
+ offs_n = offs_n + offset
583
+ kv_offset += offset
584
+
585
+
586
+ return acc, l_i, m_i
587
+ ''', device_str='cuda')
588
+
589
+
590
+ # kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/qr/cqrng7hmawuvea5b46xnw26e3vaokywqdqnuhn4vt7tmtdoleeab.py
591
+ # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul]
592
+ # Source node to ATen node mapping:
593
+ # lse_scaled => mul_9
594
+ # Graph fragment:
595
+ # %buf3 : Tensor = PlaceHolder[target=buf3]
596
+ # %mul_9 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {})
597
+ # return %mul_9
598
+ triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', '''
599
+ import triton
600
+ import triton.language as tl
601
+
602
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
603
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
604
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
605
+ triton_helpers.set_driver_to_gpu()
606
+
607
+ @triton_heuristics.pointwise(
608
+ size_hints={'x': 4096},
609
+ filename=__file__,
610
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
611
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
612
+ min_elem_per_thread=0
613
+ )
614
+ @triton.jit
615
+ def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
616
+ xoffset = tl.program_id(0) * XBLOCK
617
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
618
+ xmask = xindex < xnumel
619
+ x2 = xindex
620
+ x0 = (xindex % ks0)
621
+ x1 = triton_helpers.div_floor_integer(xindex, ks0)
622
+ tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last')
623
+ tmp1 = 0.6931471805599453
624
+ tmp2 = tmp0 * tmp1
625
+ tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask)
626
+ ''', device_str='cuda')
627
+
628
+
629
+ async_compile.wait(globals())
630
+ del async_compile
631
+
632
+ class Runner:
633
+ def __init__(self, partitions):
634
+ self.partitions = partitions
635
+
636
+ def recursively_apply_fns(self, fns):
637
+ new_callables = []
638
+ for fn, c in zip(fns, self.partitions):
639
+ new_callables.append(fn(c))
640
+ self.partitions = new_callables
641
+
642
+ def call(self, args):
643
+ arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1 = args
644
+ args.clear()
645
+ s50 = arg0_1
646
+ s0 = arg2_1
647
+ s43 = arg4_1
648
+ s37 = arg7_1
649
+ s71 = arg8_1
650
+ assert_size_stride(arg1_1, (1, 32, s37, 128), (4096*s37, 128, 4096, 1))
651
+ assert_size_stride(arg3_1, (1, 8, s0, 128), (1024*s0, 128, 1024, 1))
652
+ assert_size_stride(arg5_1, (1, 8, s43, 128), (1024*s43, 128, 1024, 1))
653
+ assert_size_stride(arg6_1, (1, 1, 1, 1), (1, 1, 1, 1))
654
+ assert_size_stride(arg9_1, (1, 1, 1), (1, 1, 1))
655
+ assert_size_stride(arg10_1, (1, 1, 1), (1, 1, 1))
656
+ assert_size_stride(arg11_1, (1, 1, 1, 1), (1, 1, 1, 1))
657
+ assert_size_stride(arg12_1, (1, 1, 1), (1, 1, 1))
658
+ assert_size_stride(arg13_1, (1, 1, 1, 1), (1, 1, 1, 1))
659
+ assert_size_stride(arg14_1, (1, 1, 1), (1, 1, 1))
660
+ assert_size_stride(arg15_1, (1, 1, 1, 1), (1, 1, 1, 1))
661
+ with torch.cuda._DeviceGuard(1):
662
+ torch.cuda.set_device(1)
663
+ buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32)
664
+ buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32)
665
+ buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16)
666
+ # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
667
+ stream1 = get_raw_stream(1)
668
+ triton_tem_fused_0.run(arg1_1, arg3_1, arg5_1, buf0, buf1, arg9_1, arg6_1, arg10_1, arg11_1, buf2, s37, s0, s43, (127 + s37) // 128, 1, 32, stream=stream1)
669
+ del arg10_1
670
+ del arg11_1
671
+ del arg1_1
672
+ del arg3_1
673
+ del arg5_1
674
+ del arg6_1
675
+ del arg9_1
676
+ del buf1
677
+ buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32)
678
+ # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul]
679
+ triton_poi_fused_mul_1_xnumel = 32*s37
680
+ stream1 = get_raw_stream(1)
681
+ triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream1)
682
+ del buf0
683
+ return (buf2, buf5, )
684
+
685
+ runner = Runner(partitions=[])
686
+ call = runner.call
687
+ recursively_apply_fns = runner.recursively_apply_fns
688
+
689
+
690
+ def benchmark_compiled_module(times=10, repeat=10):
691
+ from torch._dynamo.testing import rand_strided
692
+ from torch._inductor.utils import print_performance
693
+ arg0_1 = 128
694
+ arg1_1 = rand_strided((1, 32, 128, 128), (524288, 128, 4096, 1), device='cuda:1', dtype=torch.bfloat16)
695
+ arg2_1 = 128
696
+ arg3_1 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:1', dtype=torch.bfloat16)
697
+ arg4_1 = 128
698
+ arg5_1 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:1', dtype=torch.bfloat16)
699
+ arg6_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:1', dtype=torch.int32)
700
+ arg7_1 = 128
701
+ arg8_1 = 128
702
+ arg9_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:1', dtype=torch.int32)
703
+ arg10_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:1', dtype=torch.int32)
704
+ arg11_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:1', dtype=torch.int32)
705
+ arg12_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:1', dtype=torch.int32)
706
+ arg13_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:1', dtype=torch.int32)
707
+ arg14_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:1', dtype=torch.int32)
708
+ arg15_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:1', dtype=torch.int32)
709
+ fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1])
710
+ return print_performance(fn, times=times, repeat=repeat)
711
+
712
+
713
+ if __name__ == "__main__":
714
+ from torch._inductor.wrapper_benchmark import compiled_module_main
715
+ compiled_module_main('None', benchmark_compiled_module)
progress/SpecForge/cache/compiled_kernels/2z/c2zdv5arszdl6ednyphqfnib6jwgzomr6zt6536b7gq75kp67uvh.py ADDED
@@ -0,0 +1,1046 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['2_backward']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sh/csh76hcjkj7bc6jvydzdmaapo6vnfxlvc3xvqexzngu63td4qnjk.py
38
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul]
39
+ # Source node to ATen node mapping:
40
+ # Graph fragment:
41
+ # %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:4" = PlaceHolder[target=getitem]
42
+ # %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:4" = PlaceHolder[target=tangents_1]
43
+ # %buf0 : Tensor "bf16[1, 32, s37][32*s37, s37, 1]cuda:4" = PlaceHolder[target=buf0]
44
+ # %tangents_2 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:4" = PlaceHolder[target=tangents_2]
45
+ # %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {})
46
+ # %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_15, %primals_18, %primals_20, %primals_23, %primals_25, %primals_28, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
47
+ # return %buf0,%buf1
48
+ triton_red_fused_mul_0 = async_compile.triton('triton_red_fused_mul_0', '''
49
+ import triton
50
+ import triton.language as tl
51
+
52
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
53
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
54
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
55
+ triton_helpers.set_driver_to_gpu()
56
+
57
+ @triton_heuristics.reduction(
58
+ size_hints={'x': 32768, 'r0_': 128},
59
+ reduction_hint=ReductionHint.DEFAULT,
60
+ filename=__file__,
61
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
62
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
63
+ )
64
+ @triton.jit
65
+ def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
66
+ r0_numel = 128
67
+ rnumel = r0_numel
68
+ RBLOCK: tl.constexpr = R0_BLOCK
69
+ xoffset = tl.program_id(0) * XBLOCK
70
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
71
+ xmask = xindex < xnumel
72
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
73
+ rbase = r0_base
74
+ x0 = (xindex % ks0)
75
+ x1 = triton_helpers.div_floor_integer(xindex, ks0)
76
+ _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
77
+ x3 = xindex
78
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
79
+ r0_index = r0_offset + r0_base
80
+ r0_mask = r0_index < r0_numel
81
+ roffset = r0_offset
82
+ rindex = r0_index
83
+ r0_2 = r0_index
84
+ tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
85
+ tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
86
+ tmp2 = tmp0 * tmp1
87
+ tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
88
+ tmp5 = _tmp4 + tmp3
89
+ _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4)
90
+ tmp4 = tl.sum(_tmp4, 1)[:, None]
91
+ tmp7 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last')
92
+ tmp6 = tmp4.to(tl.float32)
93
+ tmp8 = 0.6931471805599453
94
+ tmp9 = tmp7 * tmp8
95
+ tmp10 = 1.4426950408889634
96
+ tmp11 = tmp9 * tmp10
97
+ tmp12 = tmp6 - tmp11
98
+ tl.store(out_ptr1 + (x3), tmp12, xmask)
99
+ ''', device_str='cuda')
100
+
101
+
102
+ # kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ek/cekdygnnt4twwaq4fpapciid2veg5uc5gzwte4mymxe7ertv26cs.py
103
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul]
104
+ # Source node to ATen node mapping:
105
+ # Graph fragment:
106
+ # %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:4" = PlaceHolder[target=primals_2]
107
+ # %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:4" = PlaceHolder[target=primals_4]
108
+ # %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:4" = PlaceHolder[target=primals_6]
109
+ # %getitem_1 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:4" = PlaceHolder[target=getitem_1]
110
+ # %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:4" = PlaceHolder[target=buf1]
111
+ # %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:4" = PlaceHolder[target=tangents_1]
112
+ # %getitem_3 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:4" = PlaceHolder[target=getitem_3]
113
+ # %getitem_5 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:4" = PlaceHolder[target=getitem_5]
114
+ # %primals_13 : Tensor "i32[1, 1, s99][s99, s99, 1]cuda:4" = PlaceHolder[target=primals_13]
115
+ # %primals_9 : Tensor "i32[1, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:4" = PlaceHolder[target=primals_9]
116
+ # %primals_20 : Tensor "i32[1, 1, s56][s56, s56, 1]cuda:4" = PlaceHolder[target=primals_20]
117
+ # %primals_23 : Tensor "i32[1, 1, s84, s53][s53*s84, s53*s84, s53, 1]cuda:4" = PlaceHolder[target=primals_23]
118
+ # %primals_15 : Tensor "i32[1, 1, s94][s94, s94, 1]cuda:4" = PlaceHolder[target=primals_15]
119
+ # %primals_18 : Tensor "i32[1, 1, s28, s4][s28*s4, s28*s4, s4, 1]cuda:4" = PlaceHolder[target=primals_18]
120
+ # %primals_25 : Tensor "i32[1, 1, s100][s100, s100, 1]cuda:4" = PlaceHolder[target=primals_25]
121
+ # %primals_28 : Tensor "i32[1, 1, s5, s10][s10*s5, s10*s5, s10, 1]cuda:4" = PlaceHolder[target=primals_28]
122
+ # %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {})
123
+ # %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_15, %primals_18, %primals_20, %primals_23, %primals_25, %primals_28, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
124
+ # return %getitem_4
125
+ triton_tem_fused_mul_1 = async_compile.triton('triton_tem_fused_mul_1', '''
126
+ import triton
127
+ import triton.language as tl
128
+
129
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
130
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
131
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
132
+
133
+ @triton_heuristics.template(
134
+
135
+ num_stages=3,
136
+ num_warps=8,
137
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]},
138
+ inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
139
+
140
+ )
141
+ @triton.jit
142
+ def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7):
143
+ PRESCALE_QK : tl.constexpr = False
144
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
145
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
146
+ WRITE_DQ : tl.constexpr = True
147
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
148
+ OUTPUT_MAX : tl.constexpr = False
149
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
150
+ IS_DIVISIBLE : tl.constexpr = False
151
+ SM_SCALE : tl.constexpr = 0.08838834764831845
152
+ GQA_SHARED_HEADS : tl.constexpr = 4
153
+ HAS_FULL_BLOCKS : tl.constexpr = True
154
+ QK_HEAD_DIM : tl.constexpr = 128
155
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
156
+ V_HEAD_DIM : tl.constexpr = 128
157
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
158
+ SAFE_HEAD_DIM : tl.constexpr = True
159
+ BLOCK_M1 : tl.constexpr = 64
160
+ BLOCK_N1 : tl.constexpr = 128
161
+ BLOCK_M2 : tl.constexpr = 128
162
+ BLOCK_N2 : tl.constexpr = 64
163
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
164
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
165
+ INDEX_DTYPE : tl.constexpr = tl.int32
166
+ Q = arg_Q
167
+ K = arg_K
168
+ V = arg_V
169
+ LSE = arg_LSE
170
+ DELTA = arg_DELTA
171
+ DO = arg_DO
172
+ DQ = arg_DQ
173
+ DV = arg_DV
174
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
175
+ KV_IDX = arg_KV_IDX
176
+ Q_NUM_BLKS = arg_Q_NUM_BLKS
177
+ Q_IDX = arg_Q_IDX
178
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
179
+ FULL_KV_IDX = arg_FULL_KV_IDX
180
+ FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
181
+ FULL_Q_IDX = arg_FULL_Q_IDX
182
+
183
+ # Sub notation for this kernel:
184
+ #
185
+ # Q: Query, K: Key, V: Value
186
+ # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
187
+ # DELTA: Precomputed sum(OUT*DO, axis=-1)
188
+ # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
189
+ # DK: Derivative of Key, is the written to via the store_output call due to some limitations with
190
+ # inductor codegen
191
+ # M: Number of queries, N: Number of keys/values
192
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
193
+ # V_HEAD_DIM: The dimension of the value embeddings
194
+ # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
195
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
196
+ # (Modifiable) Performance tuning options
197
+ # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
198
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
199
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
200
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
201
+ #
202
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
203
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
204
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
205
+ # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
206
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
207
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
208
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
209
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
210
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
211
+
212
+ # The below are kernel options that can be applied for certain score_mods,
213
+ # or involve a numerics vs. perf tradeoff
214
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
215
+ # about 20% more numerical error, but slightly faster.
216
+
217
+ # Define strides of inputs
218
+ stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1
219
+ stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1
220
+ stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1
221
+ stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1
222
+
223
+ stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1
224
+ stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1
225
+
226
+ ZQ = 1
227
+ HQ = 32
228
+ HKV = 8
229
+ Q_LEN = ks0
230
+ ZKV = 1
231
+ KV_LEN = ks1
232
+
233
+ MATMUL_PRECISION = Q.dtype.element_ty
234
+
235
+ pid = tl.program_id(0).to(INDEX_DTYPE)
236
+ NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
237
+ NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
238
+
239
+ off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
240
+ off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
241
+ off_zkv = off_zq % ZKV # kv batch idx
242
+
243
+ SPARSE_Z = 1
244
+ SPARSE_HQ = 1
245
+
246
+ sparse_idx_z = off_zq % SPARSE_Z
247
+
248
+ k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
249
+ v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
250
+ # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
251
+ # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
252
+ dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
253
+
254
+ # offset K, V, DV pointers for batch/kv-head
255
+ K += k_adj
256
+ V += v_adj
257
+ DV += dv_adj
258
+
259
+ RCP_LN2 = 1.44269504
260
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
261
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
262
+
263
+ if pid >= NUM_KV_BLOCKS:
264
+ off_pid = pid - NUM_KV_BLOCKS
265
+ # THIS BLOCK DOES DQ
266
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
267
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
268
+ off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
269
+ start_m2_block = off_pid % NUM_Q_BLOCKS
270
+ off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
271
+ stride_kv_num_blks_h = ks2
272
+ stride_kv_idx_h = ks3*ks4
273
+ stride_kv_idx_m = ks4
274
+
275
+ sparse_idx_hq2 = off_hq2 % SPARSE_HQ
276
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
277
+
278
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
279
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
280
+
281
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
282
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
283
+ do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
284
+ dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
285
+ off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
286
+
287
+ Q2 = Q + q_adj2
288
+ DO2 = DO + do_adj2
289
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
290
+ # if Q is broadcasted)
291
+ DQ2 = DQ + dq_adj2
292
+ LSE2 = LSE + off_chz2
293
+ DELTA2 = DELTA + off_chz2
294
+
295
+ # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
296
+ dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
297
+
298
+ start_m2 = start_m2_block * BLOCK_M2
299
+ offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
300
+
301
+ # load Q and do: they stay in SRAM throughout the inner loop.
302
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
303
+ do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
304
+
305
+ if PRESCALE_QK:
306
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
307
+
308
+ if IS_DIVISIBLE:
309
+ Di = tl.load(DELTA2 + offs_m2)
310
+ lse = tl.load(LSE2 + offs_m2)
311
+ else:
312
+ Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
313
+ lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
314
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
315
+ lse = lse[:, None]
316
+
317
+ # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
318
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
319
+ kv_indices = KV_IDX + sparse_kv_idx_offset
320
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
321
+ sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
322
+
323
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
324
+ dq = bwd_dq_inner(
325
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7,
326
+ K, V,
327
+ dq, q, do, Di, lse,
328
+ off_zq, off_hq2, offs_m2, offs_n2,
329
+ stride_kn, stride_kd, stride_vn, stride_vd,
330
+ kv_indices, sparse_kv_num_blocks,
331
+ MATMUL_PRECISION,
332
+ IS_FULL_BLOCKS=False,
333
+ )
334
+
335
+ if HAS_FULL_BLOCKS:
336
+ # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
337
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
338
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
339
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
340
+ sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
341
+
342
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
343
+ dq = bwd_dq_inner(
344
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7,
345
+ K, V,
346
+ dq, q, do, Di, lse,
347
+ off_zq, off_hq2, offs_m2, offs_n2,
348
+ stride_kn, stride_kd, stride_vn, stride_vd,
349
+ kv_indices, sparse_kv_num_blocks,
350
+ MATMUL_PRECISION,
351
+ IS_FULL_BLOCKS=True,
352
+ )
353
+
354
+ # Write back dQ.
355
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
356
+ dq *= SM_SCALE
357
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
358
+ tl.store(dq_ptrs, dq)
359
+ else:
360
+ tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
361
+ else:
362
+ # THIS BLOCK DOES DK & DV
363
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
364
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
365
+
366
+ pid_mask = pid // SPARSE_KV_MULTIPLE
367
+
368
+ stride_q_num_blks_h = ks5
369
+ stride_q_idx_h = ks6*ks7
370
+ stride_q_idx_n = ks6
371
+
372
+
373
+ dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
374
+ dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
375
+
376
+ start_n1 = pid * BLOCK_N1
377
+ offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
378
+
379
+ # load K and V: they stay in SRAM throughout the inner loop.
380
+ k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
381
+ v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
382
+
383
+ if PRESCALE_QK:
384
+ k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
385
+
386
+ for off_g in range(0, GQA_SHARED_HEADS):
387
+ off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
388
+
389
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
390
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
391
+ do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
392
+ dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
393
+ off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
394
+
395
+ Q1 = Q + q_adj1
396
+ DO1 = DO + do_adj1
397
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
398
+ # if Q is broadcasted)
399
+ LSE1 = LSE + off_chz1
400
+ DELTA1 = DELTA + off_chz1
401
+
402
+ sparse_idx_hq1 = off_hq1 % SPARSE_HQ
403
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
404
+
405
+ sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
406
+ sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
407
+
408
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
409
+ # Q_IDX and Q_NUM_BLKS are always contiguous.
410
+ q_indices = Q_IDX + sparse_q_idx_offset
411
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
412
+ sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
413
+
414
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
415
+ dk, dv = bwd_dkdv_inner(
416
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7,
417
+ Q1, DO1, DELTA1, LSE1,
418
+ dk, dv, k, v,
419
+ off_zq, off_hq1, offs_n1, offs_m1,
420
+ stride_qm, stride_qd, stride_dom, stride_dod,
421
+ q_indices, sparse_q_num_blocks,
422
+ MATMUL_PRECISION,
423
+ IS_FULL_BLOCKS=False,
424
+ )
425
+
426
+
427
+ if HAS_FULL_BLOCKS:
428
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
429
+ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
430
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset
431
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
432
+ sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
433
+
434
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
435
+ dk, dv = bwd_dkdv_inner(
436
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7,
437
+ Q1, DO1, DELTA1, LSE1,
438
+ dk, dv, k, v,
439
+ off_zq, off_hq1, offs_n1, offs_m1,
440
+ stride_qm, stride_qd, stride_dom, stride_dod,
441
+ q_indices, sparse_q_num_blocks,
442
+ MATMUL_PRECISION,
443
+ IS_FULL_BLOCKS=True,
444
+ )
445
+
446
+ # Write back dV and dK.
447
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
448
+
449
+ index_n = offs_n1[:, None]
450
+ index_k = offs_k[None, :]
451
+ index_v = offs_v[None, :]
452
+
453
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
454
+ tl.store(dv_ptrs, dv)
455
+ else:
456
+ tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
457
+
458
+ dk *= SM_SCALE
459
+
460
+ if SAFE_HEAD_DIM:
461
+ mask = index_n < KV_LEN
462
+ else:
463
+ mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
464
+
465
+ # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
466
+ # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
467
+ tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
468
+ xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1
469
+ tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask)
470
+
471
+ @triton.jit
472
+ def bwd_dq_inner(
473
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7,
474
+ K, V, # pointers
475
+ dq, q, do, Di, lse,
476
+ off_z, off_hq, offs_m2, offs_n2,
477
+ stride_kn, stride_kd, stride_vn, stride_vd,
478
+ kv_indices, sparse_kv_num_blocks,
479
+ MATMUL_PRECISION,
480
+ IS_FULL_BLOCKS,
481
+ ):
482
+ PRESCALE_QK : tl.constexpr = False
483
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
484
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
485
+ WRITE_DQ : tl.constexpr = True
486
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
487
+ OUTPUT_MAX : tl.constexpr = False
488
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
489
+ IS_DIVISIBLE : tl.constexpr = False
490
+ SM_SCALE : tl.constexpr = 0.08838834764831845
491
+ GQA_SHARED_HEADS : tl.constexpr = 4
492
+ HAS_FULL_BLOCKS : tl.constexpr = True
493
+ QK_HEAD_DIM : tl.constexpr = 128
494
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
495
+ V_HEAD_DIM : tl.constexpr = 128
496
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
497
+ SAFE_HEAD_DIM : tl.constexpr = True
498
+ BLOCK_M1 : tl.constexpr = 64
499
+ BLOCK_N1 : tl.constexpr = 128
500
+ BLOCK_M2 : tl.constexpr = 128
501
+ BLOCK_N2 : tl.constexpr = 64
502
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
503
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
504
+ INDEX_DTYPE : tl.constexpr = tl.int32
505
+
506
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
507
+ RCP_LN2: tl.constexpr = 1.44269504
508
+ Q_LEN = ks0
509
+ KV_LEN = ks1
510
+
511
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
512
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
513
+
514
+ kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
515
+ vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
516
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
517
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
518
+
519
+ hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
520
+
521
+ for start_n in range(0, hi):
522
+ dq = bwd_dq_block_mn(
523
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7,
524
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
525
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
526
+ stride_kn, stride_kd, stride_vn, stride_vd,
527
+ kv_indices, sparse_kv_num_blocks,
528
+ MATMUL_PRECISION, RCP_LN2,
529
+ IS_FULL_BLOCKS,
530
+ )
531
+
532
+ # Increment pointers.
533
+ offset = get_offset_for_next_block(
534
+ start_n, kv_indices, sparse_kv_num_blocks,
535
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
536
+ )
537
+
538
+ kT_ptrs += offset * stride_kn
539
+ vT_ptrs += offset * stride_vn
540
+
541
+ offs_n2 += offset
542
+
543
+ return dq
544
+
545
+
546
+ @triton.jit
547
+ def bwd_dq_block_mn(
548
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7,
549
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
550
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
551
+ stride_kn, stride_kd, stride_vn, stride_vd,
552
+ kv_indices, sparse_kv_num_blocks,
553
+ MATMUL_PRECISION, RCP_LN2,
554
+ IS_FULL_BLOCKS,
555
+ ):
556
+ PRESCALE_QK : tl.constexpr = False
557
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
558
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
559
+ WRITE_DQ : tl.constexpr = True
560
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
561
+ OUTPUT_MAX : tl.constexpr = False
562
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
563
+ IS_DIVISIBLE : tl.constexpr = False
564
+ SM_SCALE : tl.constexpr = 0.08838834764831845
565
+ GQA_SHARED_HEADS : tl.constexpr = 4
566
+ HAS_FULL_BLOCKS : tl.constexpr = True
567
+ QK_HEAD_DIM : tl.constexpr = 128
568
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
569
+ V_HEAD_DIM : tl.constexpr = 128
570
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
571
+ SAFE_HEAD_DIM : tl.constexpr = True
572
+ BLOCK_M1 : tl.constexpr = 64
573
+ BLOCK_N1 : tl.constexpr = 128
574
+ BLOCK_M2 : tl.constexpr = 128
575
+ BLOCK_N2 : tl.constexpr = 64
576
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
577
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
578
+ INDEX_DTYPE : tl.constexpr = tl.int32
579
+
580
+
581
+ # NB reversed order to since K is transposed
582
+ kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
583
+ qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
584
+ if not PRESCALE_QK:
585
+ qk *= SM_SCALE
586
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
587
+ pre_mod_scores = qk
588
+ n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
589
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim
590
+ # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary
591
+ m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
592
+
593
+ tmp0 = (qk)
594
+ post_mod_scores = tmp0
595
+
596
+
597
+
598
+
599
+ if not IS_DIVISIBLE:
600
+ post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
601
+
602
+ if not IS_FULL_BLOCKS:
603
+ tmp1 = (m)
604
+ tmp2 = tl.full([1], 0, tl.int32)
605
+ tmp3 = tmp1 < tmp2
606
+ tmp4 = (n)
607
+ tmp5 = tmp4 <= tmp1
608
+ tmp6 = tmp3 & tmp5
609
+ tmp7 = tmp1 >= tmp2
610
+ tmp8 = tmp4 < tmp2
611
+ tmp9 = tmp7 & tmp8
612
+ tmp10 = tmp8 == 0
613
+ tmp11 = tmp7 & tmp10
614
+ tmp12 = tmp1 - tmp2
615
+ tmp13 = tl.full([1], 16, tl.int32)
616
+ tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
617
+ tmp15 = tmp4 - tmp2
618
+ tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
619
+ tmp17 = tmp14 == tmp16
620
+ tmp18 = tmp11 & tmp17
621
+ tmp19 = tmp9 | tmp18
622
+ tmp20 = tmp6 | tmp19
623
+ mask_mod_output = tmp20
624
+
625
+
626
+ # apply mask for partial masked block
627
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
628
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
629
+ if not PRESCALE_QK:
630
+ post_mod_scores *= RCP_LN2
631
+ p = tl.math.exp2(post_mod_scores - lse)
632
+ # Compute dP and dS.
633
+ # NB reversed order to since V is transposed
634
+ vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
635
+
636
+ dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
637
+ ds = p * (dp - Di[:, None])
638
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
639
+ tmp21 = (ds)
640
+ grad_scores = tmp21
641
+
642
+
643
+ if not IS_DIVISIBLE:
644
+ grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
645
+
646
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
647
+ if WRITE_DQ:
648
+ scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
649
+
650
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
651
+ ds = grad_scores
652
+
653
+ if not IS_FULL_BLOCKS:
654
+ # (grads) apply mask for partially unmasked block
655
+ ds = tl.where(mask_mod_output, ds, 0.0)
656
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
657
+ ds = ds.to(MATMUL_PRECISION)
658
+ # Compute dQ.
659
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
660
+
661
+ return dq
662
+
663
+
664
+ @triton.jit
665
+ def bwd_dkdv_inner(
666
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7,
667
+ Q, DO, DELTA, LSE, # pointers
668
+ dk, dv, k, v,
669
+ off_z, off_hq, offs_n1, offs_m1,
670
+ stride_qm, stride_qd, stride_dom, stride_dod,
671
+ q_indices, sparse_q_num_blocks,
672
+ MATMUL_PRECISION,
673
+ IS_FULL_BLOCKS,
674
+ ):
675
+ PRESCALE_QK : tl.constexpr = False
676
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
677
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
678
+ WRITE_DQ : tl.constexpr = True
679
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
680
+ OUTPUT_MAX : tl.constexpr = False
681
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
682
+ IS_DIVISIBLE : tl.constexpr = False
683
+ SM_SCALE : tl.constexpr = 0.08838834764831845
684
+ GQA_SHARED_HEADS : tl.constexpr = 4
685
+ HAS_FULL_BLOCKS : tl.constexpr = True
686
+ QK_HEAD_DIM : tl.constexpr = 128
687
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
688
+ V_HEAD_DIM : tl.constexpr = 128
689
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
690
+ SAFE_HEAD_DIM : tl.constexpr = True
691
+ BLOCK_M1 : tl.constexpr = 64
692
+ BLOCK_N1 : tl.constexpr = 128
693
+ BLOCK_M2 : tl.constexpr = 128
694
+ BLOCK_N2 : tl.constexpr = 64
695
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
696
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
697
+ INDEX_DTYPE : tl.constexpr = tl.int32
698
+
699
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
700
+ RCP_LN2: tl.constexpr = 1.44269504
701
+ Q_LEN = ks0
702
+ KV_LEN = ks1
703
+
704
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
705
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
706
+
707
+ qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
708
+ do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
709
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
710
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
711
+
712
+ # The minimum is needed to handle the case where we run with a super large
713
+ # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
714
+ hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
715
+
716
+ for start_m in range(0, hi):
717
+ dk, dv = bwd_dkdv_block_mn(
718
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7,
719
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
720
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
721
+ stride_qm, stride_qd, stride_dom, stride_dod,
722
+ q_indices, sparse_q_num_blocks,
723
+ MATMUL_PRECISION, RCP_LN2,
724
+ IS_FULL_BLOCKS,
725
+ )
726
+ # Increment pointers.
727
+ offset = get_offset_for_next_block(
728
+ start_m, q_indices, sparse_q_num_blocks,
729
+ SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
730
+ )
731
+
732
+ qT_ptrs += offset * stride_qm
733
+ do_ptrs += offset * stride_dom
734
+ offs_m1 += offset
735
+
736
+ return dk, dv
737
+
738
+
739
+ @triton.jit
740
+ def bwd_dkdv_block_mn(
741
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7,
742
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
743
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
744
+ stride_qm, stride_qd, stride_dom, stride_dod,
745
+ q_indices, sparse_q_num_blocks,
746
+ MATMUL_PRECISION, RCP_LN2,
747
+ IS_FULL_BLOCKS,
748
+ ):
749
+ PRESCALE_QK : tl.constexpr = False
750
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
751
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
752
+ WRITE_DQ : tl.constexpr = True
753
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
754
+ OUTPUT_MAX : tl.constexpr = False
755
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
756
+ IS_DIVISIBLE : tl.constexpr = False
757
+ SM_SCALE : tl.constexpr = 0.08838834764831845
758
+ GQA_SHARED_HEADS : tl.constexpr = 4
759
+ HAS_FULL_BLOCKS : tl.constexpr = True
760
+ QK_HEAD_DIM : tl.constexpr = 128
761
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
762
+ V_HEAD_DIM : tl.constexpr = 128
763
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
764
+ SAFE_HEAD_DIM : tl.constexpr = True
765
+ BLOCK_M1 : tl.constexpr = 64
766
+ BLOCK_N1 : tl.constexpr = 128
767
+ BLOCK_M2 : tl.constexpr = 128
768
+ BLOCK_N2 : tl.constexpr = 64
769
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
770
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
771
+ INDEX_DTYPE : tl.constexpr = tl.int32
772
+
773
+
774
+ # NB reversed order since Q is transposed
775
+ qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
776
+ # Load LSE before computing qk to reduce pipeline stall.
777
+ if IS_DIVISIBLE:
778
+ lse = tl.load(LSE + offs_m1)
779
+ else:
780
+ lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
781
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
782
+ qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
783
+ if not PRESCALE_QK:
784
+ qkT *= SM_SCALE
785
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
786
+ m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
787
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim
788
+ # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary
789
+ n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
790
+
791
+ pre_mod_scores = qkT
792
+ tmp22 = (qkT)
793
+ post_mod_scores = tmp22
794
+
795
+
796
+
797
+ if not IS_DIVISIBLE:
798
+ post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
799
+
800
+ if not IS_FULL_BLOCKS:
801
+ tmp23 = (m)
802
+ tmp24 = tl.full([1], 0, tl.int32)
803
+ tmp25 = tmp23 < tmp24
804
+ tmp26 = (n)
805
+ tmp27 = tmp26 <= tmp23
806
+ tmp28 = tmp25 & tmp27
807
+ tmp29 = tmp23 >= tmp24
808
+ tmp30 = tmp26 < tmp24
809
+ tmp31 = tmp29 & tmp30
810
+ tmp32 = tmp30 == 0
811
+ tmp33 = tmp29 & tmp32
812
+ tmp34 = tmp23 - tmp24
813
+ tmp35 = tl.full([1], 16, tl.int32)
814
+ tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35)
815
+ tmp37 = tmp26 - tmp24
816
+ tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35)
817
+ tmp39 = tmp36 == tmp38
818
+ tmp40 = tmp33 & tmp39
819
+ tmp41 = tmp31 | tmp40
820
+ tmp42 = tmp28 | tmp41
821
+ mask_mod_output = tmp42
822
+
823
+ # (grads) apply mask for fully masked block
824
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
825
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
826
+ if not PRESCALE_QK:
827
+ post_mod_scores *= RCP_LN2
828
+ pT = tl.math.exp2(post_mod_scores - lse[None, :])
829
+ do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
830
+ # Compute dV.
831
+ ppT = pT
832
+ dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
833
+ if IS_DIVISIBLE:
834
+ Di = tl.load(DELTA + offs_m1)
835
+ else:
836
+ Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
837
+ # Compute dP and dS.
838
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
839
+ dsT = pT * (dpT - Di[None, :])
840
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
841
+ tmp43 = (dsT)
842
+ grad_scores = tmp43
843
+
844
+
845
+
846
+ if not IS_DIVISIBLE:
847
+ grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
848
+
849
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
850
+ if not WRITE_DQ:
851
+ idx_b = off_z
852
+ idx_h = off_hq
853
+ idx_m = m
854
+ idx_n = n
855
+ scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
856
+
857
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
858
+ dsT = grad_scores
859
+ if not IS_FULL_BLOCKS:
860
+ # (grads) apply mask for partially unmasked block
861
+ dsT = tl.where(mask_mod_output, dsT, 0.0)
862
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
863
+ dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
864
+
865
+ return dk, dv
866
+
867
+ # Utility triton funcs
868
+ @triton.jit
869
+ def get_offset_for_next_block(
870
+ loop_iter, col_indices, total_blocks,
871
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
872
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
873
+ ):
874
+ if BLOCKS_ARE_CONTIGUOUS:
875
+ return BLOCK
876
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
877
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
878
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
879
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
880
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
881
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
882
+ return offset
883
+
884
+ @triton.jit
885
+ def get_bounded_indices(indices, max_len=None):
886
+ return indices % max_len if max_len is not None else indices
887
+
888
+ @triton.jit
889
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
890
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
891
+ return tl.load(block_ptr)
892
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
893
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
894
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
895
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
896
+ else:
897
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
898
+
899
+ @triton.jit
900
+ def load_checked_2d(
901
+ ptr,
902
+ offs_m,
903
+ offs_n,
904
+ stride_m,
905
+ stride_n,
906
+ IS_DIVISIBLE_M: tl.constexpr,
907
+ IS_DIVISIBLE_N: tl.constexpr,
908
+ M_LEN: tl.constexpr,
909
+ N_LEN: tl.constexpr,
910
+ ):
911
+ # Calculate final pointer if strides are provided
912
+ if stride_m is not None and stride_n is not None:
913
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
914
+
915
+ # Handle all masking cases
916
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
917
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
918
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
919
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
920
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
921
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
922
+ else: # Both divisible
923
+ return tl.load(ptr)
924
+ ''', device_str='cuda')
925
+
926
+
927
+ async_compile.wait(globals())
928
+ del async_compile
929
+
930
+ class Runner:
931
+ def __init__(self, partitions):
932
+ self.partitions = partitions
933
+
934
+ def recursively_apply_fns(self, fns):
935
+ new_callables = []
936
+ for fn, c in zip(fns, self.partitions):
937
+ new_callables.append(fn(c))
938
+ self.partitions = new_callables
939
+
940
+ def call(self, args):
941
+ primals_10, primals_11, primals_7, primals_8, primals_12, primals_14, primals_16, primals_17, primals_19, primals_22, primals_21, primals_24, primals_27, primals_26, primals_2, primals_4, primals_6, primals_9, primals_13, primals_15, primals_18, primals_20, primals_23, primals_25, primals_28, getitem, getitem_1, tangents_1, tangents_2 = args
942
+ args.clear()
943
+ s37 = primals_10
944
+ s0 = primals_11
945
+ s22 = primals_7
946
+ s72 = primals_8
947
+ s99 = primals_12
948
+ s94 = primals_14
949
+ s28 = primals_16
950
+ s4 = primals_17
951
+ s56 = primals_19
952
+ s53 = primals_22
953
+ s84 = primals_21
954
+ s100 = primals_24
955
+ s10 = primals_27
956
+ s5 = primals_26
957
+ assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1))
958
+ assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1))
959
+ assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1))
960
+ assert_size_stride(primals_9, (1, 1, s22, s72), (s22*s72, s22*s72, s72, 1))
961
+ assert_size_stride(primals_13, (1, 1, s99), (s99, s99, 1))
962
+ assert_size_stride(primals_15, (1, 1, s94), (s94, s94, 1))
963
+ assert_size_stride(primals_18, (1, 1, s28, s4), (s28*s4, s28*s4, s4, 1))
964
+ assert_size_stride(primals_20, (1, 1, s56), (s56, s56, 1))
965
+ assert_size_stride(primals_23, (1, 1, s84, s53), (s53*s84, s53*s84, s53, 1))
966
+ assert_size_stride(primals_25, (1, 1, s100), (s100, s100, 1))
967
+ assert_size_stride(primals_28, (1, 1, s5, s10), (s10*s5, s10*s5, s10, 1))
968
+ assert_size_stride(getitem, (1, 32, s37, 128), (4096*s37, 128, 4096, 1))
969
+ assert_size_stride(getitem_1, (1, 32, s37), (32*max(1, s37), max(1, s37), 1))
970
+ assert_size_stride(tangents_1, (1, 32, s37, 128), (4096*max(1, s37), 128*max(1, s37), 128, 1))
971
+ assert_size_stride(tangents_2, (1, 32, s37), (32*max(1, s37), max(1, s37), 1))
972
+ with torch.cuda._DeviceGuard(4):
973
+ torch.cuda.set_device(4)
974
+ buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32)
975
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul]
976
+ triton_red_fused_mul_0_xnumel = 32*s37
977
+ stream4 = get_raw_stream(4)
978
+ triton_red_fused_mul_0.run(getitem, tangents_1, tangents_2, buf1, s37, triton_red_fused_mul_0_xnumel, 128, stream=stream4)
979
+ del getitem
980
+ del tangents_2
981
+ buf3 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16)
982
+ buf4 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16)
983
+ buf5 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16)
984
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul]
985
+ stream4 = get_raw_stream(4)
986
+ triton_tem_fused_mul_1.run(primals_2, primals_4, primals_6, getitem_1, buf1, tangents_1, buf3, buf4, primals_13, primals_9, primals_20, primals_23, primals_15, primals_18, primals_25, primals_28, buf5, s37, s0, s99, s22, s72, s56, s53, s84, 4*((127 + s37) // 128) + ((127 + s0) // 128), 1, 8, stream=stream4)
987
+ del buf1
988
+ del getitem_1
989
+ del primals_13
990
+ del primals_15
991
+ del primals_18
992
+ del primals_2
993
+ del primals_20
994
+ del primals_23
995
+ del primals_25
996
+ del primals_28
997
+ del primals_4
998
+ del primals_6
999
+ del primals_9
1000
+ del tangents_1
1001
+ return (None, buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, )
1002
+
1003
+ runner = Runner(partitions=[])
1004
+ call = runner.call
1005
+ recursively_apply_fns = runner.recursively_apply_fns
1006
+
1007
+
1008
+ def benchmark_compiled_module(times=10, repeat=10):
1009
+ from torch._dynamo.testing import rand_strided
1010
+ from torch._inductor.utils import print_performance
1011
+ primals_10 = 960
1012
+ primals_11 = 960
1013
+ primals_7 = 8
1014
+ primals_8 = 8
1015
+ primals_12 = 8
1016
+ primals_14 = 8
1017
+ primals_16 = 8
1018
+ primals_17 = 8
1019
+ primals_19 = 8
1020
+ primals_22 = 8
1021
+ primals_21 = 8
1022
+ primals_24 = 8
1023
+ primals_27 = 8
1024
+ primals_26 = 8
1025
+ primals_2 = rand_strided((1, 32, 960, 128), (3932160, 128, 4096, 1), device='cuda:4', dtype=torch.bfloat16)
1026
+ primals_4 = rand_strided((1, 8, 960, 128), (983040, 128, 1024, 1), device='cuda:4', dtype=torch.bfloat16)
1027
+ primals_6 = rand_strided((1, 8, 960, 128), (983040, 128, 1024, 1), device='cuda:4', dtype=torch.bfloat16)
1028
+ primals_9 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:4', dtype=torch.int32)
1029
+ primals_13 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:4', dtype=torch.int32)
1030
+ primals_15 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:4', dtype=torch.int32)
1031
+ primals_18 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:4', dtype=torch.int32)
1032
+ primals_20 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:4', dtype=torch.int32)
1033
+ primals_23 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:4', dtype=torch.int32)
1034
+ primals_25 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:4', dtype=torch.int32)
1035
+ primals_28 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:4', dtype=torch.int32)
1036
+ getitem = rand_strided((1, 32, 960, 128), (3932160, 128, 4096, 1), device='cuda:4', dtype=torch.bfloat16)
1037
+ getitem_1 = rand_strided((1, 32, 960), (30720, 960, 1), device='cuda:4', dtype=torch.float32)
1038
+ tangents_1 = rand_strided((1, 32, 960, 128), (3932160, 122880, 128, 1), device='cuda:4', dtype=torch.bfloat16)
1039
+ tangents_2 = rand_strided((1, 32, 960), (30720, 960, 1), device='cuda:4', dtype=torch.float32)
1040
+ fn = lambda: call([primals_10, primals_11, primals_7, primals_8, primals_12, primals_14, primals_16, primals_17, primals_19, primals_22, primals_21, primals_24, primals_27, primals_26, primals_2, primals_4, primals_6, primals_9, primals_13, primals_15, primals_18, primals_20, primals_23, primals_25, primals_28, getitem, getitem_1, tangents_1, tangents_2])
1041
+ return print_performance(fn, times=times, repeat=repeat)
1042
+
1043
+
1044
+ if __name__ == "__main__":
1045
+ from torch._inductor.wrapper_benchmark import compiled_module_main
1046
+ compiled_module_main('None', benchmark_compiled_module)
progress/SpecForge/cache/compiled_kernels/2z/c2zqq6qyjomc7iflknbqr7yjdhjux47hzv4nnsi5qfbeqglaip2h.py ADDED
@@ -0,0 +1,707 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['4_forward']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/cf/ccftanvnrini6kruughcnjtpfiarn7zwa2sdotthpo3wbbjituv3.py
38
+ # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
39
+ # Source node to ATen node mapping:
40
+ # flex_attention => flex_attention
41
+ # Graph fragment:
42
+ # %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:0" = PlaceHolder[target=primals_2]
43
+ # %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:0" = PlaceHolder[target=primals_4]
44
+ # %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:0" = PlaceHolder[target=primals_6]
45
+ # %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:0" = PlaceHolder[target=getitem_1]
46
+ # %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:0" = PlaceHolder[target=buf1]
47
+ # %primals_10 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:0" = PlaceHolder[target=primals_10]
48
+ # %primals_7 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:0" = PlaceHolder[target=primals_7]
49
+ # %primals_11 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:0" = PlaceHolder[target=primals_11]
50
+ # %primals_12 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:0" = PlaceHolder[target=primals_12]
51
+ # %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
52
+ # return %getitem
53
+ triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', '''
54
+ import triton
55
+ import triton.language as tl
56
+
57
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
58
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
59
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
60
+
61
+ @triton_heuristics.template(
62
+
63
+ num_stages=3,
64
+ num_warps=8,
65
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]},
66
+ inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
67
+
68
+ )
69
+ @triton.jit
70
+ def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1):
71
+ PRESCALE_QK : tl.constexpr = False
72
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
73
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
74
+ WRITE_DQ : tl.constexpr = True
75
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
76
+ OUTPUT_MAX : tl.constexpr = False
77
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
78
+ IS_DIVISIBLE : tl.constexpr = False
79
+ SM_SCALE : tl.constexpr = 0.08838834764831845
80
+ GQA_SHARED_HEADS : tl.constexpr = 4
81
+ HAS_FULL_BLOCKS : tl.constexpr = True
82
+ QK_HEAD_DIM : tl.constexpr = 128
83
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
84
+ V_HEAD_DIM : tl.constexpr = 128
85
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
86
+ SAFE_HEAD_DIM : tl.constexpr = True
87
+ USE_TMA : tl.constexpr = False
88
+ BLOCK_M : tl.constexpr = 128
89
+ BLOCK_N : tl.constexpr = 64
90
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
91
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
92
+ INDEX_DTYPE : tl.constexpr = tl.int32
93
+ Q = arg_Q
94
+ K = arg_K
95
+ V = arg_V
96
+ LSE = arg_LSE
97
+ MAX = arg_MAX
98
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
99
+ KV_IDX = arg_KV_IDX
100
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
101
+ FULL_KV_IDX = arg_FULL_KV_IDX
102
+
103
+ # Sub notation for this kernel:
104
+ #
105
+ # Q: Query, K: Key, V: Value
106
+ # M: Number of queries, N: Number of keys/values, D: Model dimension
107
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
108
+ # V_HEAD_DIM: The dimension of the value embeddings
109
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
110
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
111
+ #
112
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
113
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
114
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
115
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
116
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
117
+ #
118
+ # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
119
+ #
120
+ # (Modifiable) Performance tuning options
121
+ # BLOCK_M: The thread block size across the seqlen dim of Q.
122
+ # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
123
+
124
+ # The below are kernel options that can be applied for certain score_mods,
125
+ # or involve a numerics vs. perf tradeoff
126
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
127
+ # about 20% more numerical error, but slightly faster.
128
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
129
+ # is not masked out? If so, we can skip an extra safety check
130
+ # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
131
+ # contiguous? If so, we don't need to do an indirect jump for every block
132
+
133
+ tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
134
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
135
+
136
+ # Define strides of inputs
137
+ stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1
138
+ stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1
139
+ stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1
140
+
141
+ ZQ = 1
142
+ HQ = 32
143
+ Q_LEN = ks0
144
+ ZKV = 1
145
+ KV_LEN = ks1
146
+
147
+ MATMUL_PRECISION = Q.dtype.element_ty
148
+
149
+ q_start = tl.program_id(0).to(INDEX_DTYPE)
150
+ off_zq = tl.program_id(1).to(INDEX_DTYPE)
151
+ off_hq = tl.program_id(2).to(INDEX_DTYPE)
152
+
153
+ # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
154
+ # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
155
+ off_zkv = off_zq % ZKV
156
+ off_hkv = off_hq // GQA_SHARED_HEADS
157
+ off_g = off_hq % GQA_SHARED_HEADS
158
+
159
+ q_offset = off_zq * stride_qz + off_hq * stride_qh
160
+ k_offset = off_zkv * stride_kz + off_hkv * stride_kh
161
+ v_offset = off_zkv * stride_vz + off_hkv * stride_vh
162
+
163
+ Q = Q + q_offset
164
+ K = K + k_offset
165
+ V = V + v_offset
166
+
167
+ # Setting up the TMA descriptors for Q, K, V
168
+ desc_q = None
169
+ desc_k = None
170
+ desc_v = None
171
+
172
+ SPARSE_Z = 1
173
+ SPARSE_HQ = 1
174
+
175
+ sparse_idx_z = off_zq % SPARSE_Z
176
+ sparse_idx_hq = off_hq % SPARSE_HQ
177
+
178
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
179
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
180
+
181
+ stride_kv_num_blks_h = 1
182
+ stride_kv_idx_h = 1
183
+ stride_kv_idx_m = 1
184
+
185
+ # initialize pointer to m and l
186
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
187
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
188
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
189
+
190
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
191
+
192
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
193
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
194
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
195
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
196
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
197
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
198
+ q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
199
+
200
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
201
+ # We don't know anything "special" about these blocks, so we need to apply
202
+ # both score_mod and mask_mod to it
203
+ kv_indices = KV_IDX + sparse_kv_idx_offset
204
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
205
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
206
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
207
+
208
+
209
+ # K and V pointers will be passed directly to forward_inner
210
+
211
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
212
+
213
+
214
+ acc, l_i, m_i = forward_inner(
215
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
216
+ q, K, V,
217
+ desc_k, desc_v, Q_LEN, KV_LEN,
218
+ acc, l_i, m_i,
219
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
220
+ kv_start,
221
+ kv_indices, kv_num_blocks,
222
+ 0, block_n_end,
223
+ MATMUL_PRECISION,
224
+ stride_kk, stride_kn, stride_vn, stride_vk,
225
+ IS_FULL_BLOCKS=False,
226
+ )
227
+
228
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
229
+ # We know these blocks are guaranteed to be "full", so we don't need to
230
+ # apply mask_mod to them - only score_mod
231
+ if HAS_FULL_BLOCKS:
232
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
233
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
234
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
235
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
236
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
237
+ # K and V pointers will be passed directly to forward_inner
238
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
239
+
240
+ acc, l_i, m_i = forward_inner(
241
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
242
+ q, K, V,
243
+ desc_k, desc_v, Q_LEN, KV_LEN,
244
+ acc, l_i, m_i,
245
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
246
+ kv_start,
247
+ kv_indices, kv_num_blocks,
248
+ 0, block_n_end,
249
+ MATMUL_PRECISION,
250
+ stride_kk, stride_kn, stride_vn, stride_vk,
251
+ IS_FULL_BLOCKS=True,
252
+ )
253
+
254
+
255
+ # [Note] Handle fully masked out rows:
256
+ # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
257
+ # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
258
+ l_i = tl.where(l_i == 0.0, 1, l_i)
259
+
260
+ acc = acc / l_i[:, None]
261
+ idx_zq = tl.program_id(1).to(INDEX_DTYPE)
262
+ idx_hq = tl.program_id(2).to(INDEX_DTYPE)
263
+ idx_m = offs_m[:, None].to(INDEX_DTYPE)
264
+ idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
265
+
266
+ mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
267
+
268
+ tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
269
+ xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0
270
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask)
271
+
272
+ if OUTPUT_LOGSUMEXP:
273
+ off_hz = off_zq * HQ + off_hq
274
+ l_ptrs = LSE + off_hz * Q_LEN + offs_m
275
+ lse = m_i + tl.math.log2(l_i)
276
+ if IS_DIVISIBLE:
277
+ tl.store(l_ptrs, lse)
278
+ else:
279
+ tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
280
+
281
+ if OUTPUT_MAX:
282
+ off_hz = off_zq * HQ + off_hq
283
+ max_ptrs = MAX + off_hz * Q_LEN + offs_m
284
+ if IS_DIVISIBLE:
285
+ tl.store(max_ptrs, m_i)
286
+ else:
287
+ tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
288
+
289
+
290
+ # Utility triton funcs
291
+ @triton.jit
292
+ def get_offset_for_next_block(
293
+ loop_iter, col_indices, total_blocks,
294
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
295
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
296
+ ):
297
+ if BLOCKS_ARE_CONTIGUOUS:
298
+ return BLOCK
299
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
300
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
301
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
302
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
303
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
304
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
305
+ return offset
306
+
307
+ @triton.jit
308
+ def get_bounded_indices(indices, max_len=None):
309
+ return indices % max_len if max_len is not None else indices
310
+
311
+ @triton.jit
312
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
313
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
314
+ return tl.load(block_ptr)
315
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
316
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
317
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
318
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
319
+ else:
320
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
321
+
322
+ @triton.jit
323
+ def load_checked_2d(
324
+ ptr,
325
+ offs_m,
326
+ offs_n,
327
+ stride_m,
328
+ stride_n,
329
+ IS_DIVISIBLE_M: tl.constexpr,
330
+ IS_DIVISIBLE_N: tl.constexpr,
331
+ M_LEN: tl.constexpr,
332
+ N_LEN: tl.constexpr,
333
+ ):
334
+ # Calculate final pointer if strides are provided
335
+ if stride_m is not None and stride_n is not None:
336
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
337
+
338
+ # Handle all masking cases
339
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
340
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
341
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
342
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
343
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
344
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
345
+ else: # Both divisible
346
+ return tl.load(ptr)
347
+
348
+
349
+ # Common Imports
350
+ @triton.jit
351
+ def forward_block_mn(
352
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
353
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
354
+ # accumulated values
355
+ acc, l_i, m_i,
356
+ # Offsets
357
+ off_z, off_h, offs_m, offs_n,
358
+ # Offsets needed for TMA loads
359
+ kv_start,
360
+ kv_offset,
361
+ MATMUL_PRECISION, RCP_LN2,
362
+ # Strides for K and V
363
+ stride_kk, stride_kn, stride_vn, stride_vk,
364
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
365
+
366
+ ):
367
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
368
+ PRESCALE_QK : tl.constexpr = False
369
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
370
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
371
+ WRITE_DQ : tl.constexpr = True
372
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
373
+ OUTPUT_MAX : tl.constexpr = False
374
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
375
+ IS_DIVISIBLE : tl.constexpr = False
376
+ SM_SCALE : tl.constexpr = 0.08838834764831845
377
+ GQA_SHARED_HEADS : tl.constexpr = 4
378
+ HAS_FULL_BLOCKS : tl.constexpr = True
379
+ QK_HEAD_DIM : tl.constexpr = 128
380
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
381
+ V_HEAD_DIM : tl.constexpr = 128
382
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
383
+ SAFE_HEAD_DIM : tl.constexpr = True
384
+ USE_TMA : tl.constexpr = False
385
+ BLOCK_M : tl.constexpr = 128
386
+ BLOCK_N : tl.constexpr = 64
387
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
388
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
389
+ INDEX_DTYPE : tl.constexpr = tl.int32
390
+
391
+
392
+ # -- load k --
393
+ # NB reversed order to since K is transposed
394
+ kv_base_offset = kv_start + kv_offset
395
+
396
+ # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
397
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
398
+ offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
399
+ k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
400
+
401
+ k = tl.trans(k)
402
+ # -- compute qk ---
403
+ qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
404
+ if not PRESCALE_QK:
405
+ qk *= SM_SCALE
406
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
407
+ # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
408
+ # which is larger than the actual number of elements. To avoid access memory out of bound,
409
+ # we need to mask out the elements that are out of Q_LEN & KV_LEN.
410
+ m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
411
+ n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
412
+
413
+ tmp0 = (qk)
414
+ post_mod_scores = tmp0
415
+
416
+
417
+ if CHECK_BLOCK_BOUNDARY:
418
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
419
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
420
+
421
+ if not IS_FULL_BLOCKS:
422
+ tmp1 = (m)
423
+ tmp2 = tl.full([1], 0, tl.int32)
424
+ tmp3 = tmp1 < tmp2
425
+ tmp4 = (n)
426
+ tmp5 = tmp4 <= tmp1
427
+ tmp6 = tmp3 & tmp5
428
+ tmp7 = tmp1 >= tmp2
429
+ tmp8 = tmp4 < tmp2
430
+ tmp9 = tmp7 & tmp8
431
+ tmp10 = tmp8 == 0
432
+ tmp11 = tmp7 & tmp10
433
+ tmp12 = tmp1 - tmp2
434
+ tmp13 = tl.full([1], 16, tl.int32)
435
+ tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
436
+ tmp15 = tmp4 - tmp2
437
+ tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
438
+ tmp17 = tmp14 == tmp16
439
+ tmp18 = tmp11 & tmp17
440
+ tmp19 = tmp9 | tmp18
441
+ tmp20 = tmp6 | tmp19
442
+ mask_mod_output = tmp20
443
+
444
+
445
+ if CHECK_BLOCK_BOUNDARY:
446
+ mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
447
+ # apply mask for partially unmasked blocks
448
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
449
+
450
+ if not PRESCALE_QK:
451
+ post_mod_scores *= RCP_LN2
452
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
453
+
454
+ # -- compute scaling constant ---
455
+ m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
456
+ if not ROWS_GUARANTEED_SAFE:
457
+ masked_out_rows = (m_ij == float("-inf"))
458
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
459
+ else:
460
+ m_ij_masked = m_ij
461
+
462
+ alpha = tl.math.exp2(m_i - m_ij_masked)
463
+ p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
464
+
465
+ # NB: l_i update is pulled up here since it's a bit faster
466
+ # NB: For headdim=256, it's faster to move it back down to after m_i =
467
+ # m_ij
468
+ l_i = l_i * alpha + tl.sum(p, 1)
469
+ # # -- scale and update acc --
470
+ acc = acc * alpha[:, None]
471
+ # Calculate offsets for V loading - reuse kv_base_offset from K loading
472
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
473
+ v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
474
+ acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
475
+
476
+ # -- update m_i
477
+ m_i = m_ij
478
+
479
+ return acc, l_i, m_i
480
+
481
+ @triton.jit
482
+ def forward_inner(
483
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
484
+ q, K, V,
485
+ desc_k, desc_v, Q_LEN, KV_LEN,
486
+ # accumulated values
487
+ acc, l_i, m_i,
488
+ # Offsets used as inputs to score_mod & mask_mod
489
+ # of size [BLOCK_M, BLOCK_N] or scalar.
490
+ off_z, off_h, offs_m, offs_n,
491
+ # Offsets needed for TMA loads
492
+ kv_start,
493
+ # blocksparse data
494
+ kv_indices, kv_num_blocks,
495
+ # start kv and end kv block
496
+ block_n_start, block_n_end,
497
+ MATMUL_PRECISION,
498
+ # Strides for K and V
499
+ stride_kk, stride_kn, stride_vn, stride_vk,
500
+ IS_FULL_BLOCKS,
501
+ ):
502
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
503
+ PRESCALE_QK : tl.constexpr = False
504
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
505
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
506
+ WRITE_DQ : tl.constexpr = True
507
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
508
+ OUTPUT_MAX : tl.constexpr = False
509
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
510
+ IS_DIVISIBLE : tl.constexpr = False
511
+ SM_SCALE : tl.constexpr = 0.08838834764831845
512
+ GQA_SHARED_HEADS : tl.constexpr = 4
513
+ HAS_FULL_BLOCKS : tl.constexpr = True
514
+ QK_HEAD_DIM : tl.constexpr = 128
515
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
516
+ V_HEAD_DIM : tl.constexpr = 128
517
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
518
+ SAFE_HEAD_DIM : tl.constexpr = True
519
+ USE_TMA : tl.constexpr = False
520
+ BLOCK_M : tl.constexpr = 128
521
+ BLOCK_N : tl.constexpr = 64
522
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
523
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
524
+ INDEX_DTYPE : tl.constexpr = tl.int32
525
+
526
+
527
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
528
+ RCP_LN2: tl.constexpr = 1.44269504
529
+
530
+ if PRESCALE_QK:
531
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
532
+
533
+ kv_offset = 0
534
+
535
+ # loop over k, v and update accumulator until block_n_end
536
+ for start_n in range(block_n_start, block_n_end):
537
+ # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
538
+ if IS_DIVISIBLE:
539
+ acc, l_i, m_i = forward_block_mn(
540
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
541
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
542
+ # accumulated values
543
+ acc, l_i, m_i,
544
+ # Offsets
545
+ off_z, off_h, offs_m, offs_n,
546
+ # Offsets needed for TMA loads
547
+ kv_start,
548
+ kv_offset,
549
+ MATMUL_PRECISION, RCP_LN2,
550
+ # Strides for K and V
551
+ stride_kk, stride_kn, stride_vn, stride_vk,
552
+ IS_FULL_BLOCKS,
553
+ )
554
+ else:
555
+ # Benchmark shows even we applied mod & mask to each block for non divisible seqlen,
556
+ # it's on par or slightly faster than only applying to the last block in fwd.
557
+ # However, we choose different strategy for bwd, where we only apply mod & mask
558
+ # to the last block because it's faster a lot.
559
+ acc, l_i, m_i = forward_block_mn(
560
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
561
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
562
+ # accumulated values
563
+ acc, l_i, m_i,
564
+ # Offsets
565
+ off_z, off_h, offs_m, offs_n,
566
+ # Offsets needed for TMA loads
567
+ kv_start,
568
+ kv_offset,
569
+ MATMUL_PRECISION, RCP_LN2,
570
+ # Strides for K and V
571
+ stride_kk, stride_kn, stride_vn, stride_vk,
572
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
573
+ )
574
+
575
+
576
+
577
+ offset = get_offset_for_next_block(
578
+ start_n, kv_indices, kv_num_blocks,
579
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
580
+ )
581
+
582
+ offs_n = offs_n + offset
583
+ kv_offset += offset
584
+
585
+
586
+ return acc, l_i, m_i
587
+ ''', device_str='cuda')
588
+
589
+
590
+ # kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/uu/cuu2rr2yygwarlbfvcbucg7erbfsky4wxudbfsdny5wzgxewg4ut.py
591
+ # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul]
592
+ # Source node to ATen node mapping:
593
+ # lse_scaled => mul_15
594
+ # Graph fragment:
595
+ # %buf3 : Tensor = PlaceHolder[target=buf3]
596
+ # %mul_15 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {})
597
+ # return %mul_15
598
+ triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', '''
599
+ import triton
600
+ import triton.language as tl
601
+
602
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
603
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
604
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
605
+ triton_helpers.set_driver_to_gpu()
606
+
607
+ @triton_heuristics.pointwise(
608
+ size_hints={'x': 4096},
609
+ filename=__file__,
610
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
611
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
612
+ min_elem_per_thread=0
613
+ )
614
+ @triton.jit
615
+ def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
616
+ xoffset = tl.program_id(0) * XBLOCK
617
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
618
+ xmask = xindex < xnumel
619
+ x2 = xindex
620
+ x0 = (xindex % ks0)
621
+ x1 = triton_helpers.div_floor_integer(xindex, ks0)
622
+ tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last')
623
+ tmp1 = 0.6931471805599453
624
+ tmp2 = tmp0 * tmp1
625
+ tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask)
626
+ ''', device_str='cuda')
627
+
628
+
629
+ async_compile.wait(globals())
630
+ del async_compile
631
+
632
+ class Runner:
633
+ def __init__(self, partitions):
634
+ self.partitions = partitions
635
+
636
+ def recursively_apply_fns(self, fns):
637
+ new_callables = []
638
+ for fn, c in zip(fns, self.partitions):
639
+ new_callables.append(fn(c))
640
+ self.partitions = new_callables
641
+
642
+ def call(self, args):
643
+ primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16 = args
644
+ args.clear()
645
+ s50 = primals_1
646
+ s0 = primals_3
647
+ s43 = primals_5
648
+ s37 = primals_8
649
+ s71 = primals_9
650
+ assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1))
651
+ assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1))
652
+ assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1))
653
+ assert_size_stride(primals_7, (1, 1, 1, 1), (1, 1, 1, 1))
654
+ assert_size_stride(primals_10, (1, 1, 1), (1, 1, 1))
655
+ assert_size_stride(primals_11, (1, 1, 1), (1, 1, 1))
656
+ assert_size_stride(primals_12, (1, 1, 1, 1), (1, 1, 1, 1))
657
+ assert_size_stride(primals_13, (1, 1, 1), (1, 1, 1))
658
+ assert_size_stride(primals_14, (1, 1, 1, 1), (1, 1, 1, 1))
659
+ assert_size_stride(primals_15, (1, 1, 1), (1, 1, 1))
660
+ assert_size_stride(primals_16, (1, 1, 1, 1), (1, 1, 1, 1))
661
+ with torch.cuda._DeviceGuard(0):
662
+ torch.cuda.set_device(0)
663
+ buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32)
664
+ buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32)
665
+ buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16)
666
+ # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
667
+ stream0 = get_raw_stream(0)
668
+ triton_tem_fused_0.run(primals_2, primals_4, primals_6, buf0, buf1, primals_10, primals_7, primals_11, primals_12, buf2, s37, s0, (127 + s37) // 128, 1, 32, stream=stream0)
669
+ del buf1
670
+ buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32)
671
+ # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul]
672
+ triton_poi_fused_mul_1_xnumel = 32*s37
673
+ stream0 = get_raw_stream(0)
674
+ triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream0)
675
+ return (buf2, buf5, primals_2, primals_4, primals_6, primals_7, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, buf2, buf0, s37, s0, )
676
+
677
+ runner = Runner(partitions=[])
678
+ call = runner.call
679
+ recursively_apply_fns = runner.recursively_apply_fns
680
+
681
+
682
+ def benchmark_compiled_module(times=10, repeat=10):
683
+ from torch._dynamo.testing import rand_strided
684
+ from torch._inductor.utils import print_performance
685
+ primals_1 = 128
686
+ primals_2 = rand_strided((1, 32, 128, 128), (524288, 128, 4096, 1), device='cuda:0', dtype=torch.bfloat16)
687
+ primals_3 = 128
688
+ primals_4 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
689
+ primals_5 = 128
690
+ primals_6 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
691
+ primals_7 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:0', dtype=torch.int32)
692
+ primals_8 = 128
693
+ primals_9 = 128
694
+ primals_10 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:0', dtype=torch.int32)
695
+ primals_11 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:0', dtype=torch.int32)
696
+ primals_12 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:0', dtype=torch.int32)
697
+ primals_13 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:0', dtype=torch.int32)
698
+ primals_14 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:0', dtype=torch.int32)
699
+ primals_15 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:0', dtype=torch.int32)
700
+ primals_16 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:0', dtype=torch.int32)
701
+ fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16])
702
+ return print_performance(fn, times=times, repeat=repeat)
703
+
704
+
705
+ if __name__ == "__main__":
706
+ from torch._inductor.wrapper_benchmark import compiled_module_main
707
+ compiled_module_main('None', benchmark_compiled_module)
progress/SpecForge/cache/compiled_kernels/32/8d96bbe05a966b7e7756831f09a79e31bf46fad0952af86f36d75557fc1735e8.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 512, "num_warps": 4, "num_stages": 1, "configs_hash": "7cced77f371acaa5aa7d90332a90e0c907727cfefb71d9cc9d997c24557fc44f", "found_by_coordesc": false, "time_taken_ms": 13, "triton_cache_hash": "BGHEC74L2RGBNBI3A4UJOTHXFUUKS4KY3KJKVN65FHLWR47O6USQ"}
progress/SpecForge/cache/compiled_kernels/32/c32pbcuz72bjfnkzvckfbbzlzuupc5yxl7t47b3qf74mmk5g2d2z.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.pointwise(
11
+ size_hints={'x': 65536},
12
+ filename=__file__,
13
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 282624}},
15
+ min_elem_per_thread=0
16
+ )
17
+ @triton.jit
18
+ def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
19
+ xnumel = 35328
20
+ xoffset = tl.program_id(0) * XBLOCK
21
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
22
+ xmask = xindex < xnumel
23
+ x0 = xindex
24
+ tmp0 = tl.load(in_ptr0 + (x0), xmask)
25
+ tmp1 = 0.6931471805599453
26
+ tmp2 = tmp0 * tmp1
27
+ tl.store(out_ptr0 + (x0), tmp2, xmask)
progress/SpecForge/cache/compiled_kernels/3b/a0a6b043ab548fdf71e72bbdf5daab7f72e9ed11a9ad9f8824a6263bb6bc5081.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 512, "num_warps": 4, "num_stages": 1, "configs_hash": "7cced77f371acaa5aa7d90332a90e0c907727cfefb71d9cc9d997c24557fc44f", "found_by_coordesc": false, "time_taken_ms": 14, "triton_cache_hash": "DSCNRRQHW6TSEFKL6AMK6FYZWMIHBTRCG2BE5YK5T7Q76TMOZ5HQ"}
progress/SpecForge/cache/compiled_kernels/3b/c3bqw7dk7k6dcdrp3ycrthotye7y6zb26752jl4lwmfgaybpvr6y.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.pointwise(
11
+ size_hints={'x': 65536},
12
+ filename=__file__,
13
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 348160}},
15
+ min_elem_per_thread=0
16
+ )
17
+ @triton.jit
18
+ def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
19
+ xnumel = 43520
20
+ xoffset = tl.program_id(0) * XBLOCK
21
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
22
+ xmask = xindex < xnumel
23
+ x0 = xindex
24
+ tmp0 = tl.load(in_ptr0 + (x0), xmask)
25
+ tmp1 = 0.6931471805599453
26
+ tmp2 = tmp0 * tmp1
27
+ tl.store(out_ptr0 + (x0), tmp2, xmask)
progress/SpecForge/cache/compiled_kernels/3f/3f6057605b157d44fd56f748226a63975b79198f94871188e73e46cd6c7f8792.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 128, "num_warps": 8, "num_stages": 1, "configs_hash": "1542f544a12adfb1397c535fa16687cc79c79a22e4c9cd8af0b373891f747e62", "found_by_coordesc": false, "time_taken_ms": 60, "triton_cache_hash": "XRPIXE6422Z3WVFKM6FTH3VU3RBLBAM5QFGQDRDJKHCOAJAWTZHQ"}
progress/SpecForge/cache/compiled_kernels/3f/c3fttv7enp2yvnla3r6jkk4galt2qdpxw577ghvkmmx6zqaqla74.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.persistent_reduction(
11
+ size_hints={'x': 524288, 'r0_': 32},
12
+ reduction_hint=ReductionHint.DEFAULT,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_2', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': None, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_per_fused_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr):
19
+ r0_numel = 32
20
+ R0_BLOCK: tl.constexpr = 32
21
+ rnumel = r0_numel
22
+ RBLOCK: tl.constexpr = R0_BLOCK
23
+ xoffset = tl.program_id(0) * XBLOCK
24
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
25
+ xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
26
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
27
+ r0_offset = 0
28
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
29
+ roffset = r0_offset
30
+ rindex = r0_index
31
+ r0_2 = r0_index
32
+ x5 = xindex
33
+ x1 = xindex // 128
34
+ x0 = (xindex % 128)
35
+ x3 = ((xindex // 128) % ks0)
36
+ x4 = xindex // ks1
37
+ tmp0 = tl.load(in_ptr0 + (x5 + 4096*ks0*r0_2), None)
38
+ tmp1 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last')
39
+ tmp4 = tl.load(in_ptr2 + (x1 + 32*ks0*r0_2), None, eviction_policy='evict_last')
40
+ tmp13 = tl.load(in_ptr3 + (x1), None, eviction_policy='evict_last')
41
+ tmp2 = float("-inf")
42
+ tmp3 = tmp1 == tmp2
43
+ tmp5 = tmp4 - tmp1
44
+ tmp6 = 0.0
45
+ tmp7 = tl.where(tmp3, tmp6, tmp5)
46
+ tmp8 = libdevice.exp2(tmp7)
47
+ tmp9 = tmp0 * tmp8
48
+ tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK])
49
+ tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32)
50
+ tmp14 = 1.0
51
+ tmp15 = tl.where(tmp3, tmp14, tmp13)
52
+ tmp16 = (tmp12 / tmp15)
53
+ tmp17 = tmp16.to(tl.float32)
54
+ tl.store(out_ptr1 + (x0 + 128*x4 + 4096*x3), tmp17, None)
progress/SpecForge/cache/compiled_kernels/3n/c3nlaqknekmjv2zuxzow4rf42v3gorxnfp6uod3dg3ic5ibp6yp3.py ADDED
@@ -0,0 +1,715 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['1_inference']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl76p6rje3cyrrbyvxjjj7oxbieltfs4p5xqjre35l6wnofhynby.py
38
+ # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
39
+ # Source node to ATen node mapping:
40
+ # flex_attention => flex_attention
41
+ # Graph fragment:
42
+ # %arg1_1 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:1" = PlaceHolder[target=arg1_1]
43
+ # %arg3_1 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:1" = PlaceHolder[target=arg3_1]
44
+ # %arg5_1 : Tensor "bf16[1, 8, s43, 128][1024*s43, 128, 1024, 1]cuda:1" = PlaceHolder[target=arg5_1]
45
+ # %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:1" = PlaceHolder[target=getitem_1]
46
+ # %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:1" = PlaceHolder[target=buf1]
47
+ # %arg9_1 : Tensor "i32[1, 1, 5][5, 5, 1]cuda:1" = PlaceHolder[target=arg9_1]
48
+ # %arg6_1 : Tensor "i32[1, 1, 5, 5][25, 25, 5, 1]cuda:1" = PlaceHolder[target=arg6_1]
49
+ # %arg10_1 : Tensor "i32[1, 1, 5][5, 5, 1]cuda:1" = PlaceHolder[target=arg10_1]
50
+ # %arg11_1 : Tensor "i32[1, 1, 5, 5][25, 25, 5, 1]cuda:1" = PlaceHolder[target=arg11_1]
51
+ # %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
52
+ # return %getitem
53
+ triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', '''
54
+ import triton
55
+ import triton.language as tl
56
+
57
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
58
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
59
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
60
+
61
+ @triton_heuristics.template(
62
+
63
+ num_stages=3,
64
+ num_warps=8,
65
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]},
66
+ inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
67
+
68
+ )
69
+ @triton.jit
70
+ def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2):
71
+ PRESCALE_QK : tl.constexpr = False
72
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
73
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
74
+ WRITE_DQ : tl.constexpr = True
75
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
76
+ OUTPUT_MAX : tl.constexpr = False
77
+ FLOAT32_PRECISION : tl.constexpr = 'ieee'
78
+ IS_DIVISIBLE : tl.constexpr = False
79
+ SM_SCALE : tl.constexpr = 0.08838834764831845
80
+ GQA_SHARED_HEADS : tl.constexpr = 4
81
+ HAS_FULL_BLOCKS : tl.constexpr = True
82
+ QK_HEAD_DIM : tl.constexpr = 128
83
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
84
+ V_HEAD_DIM : tl.constexpr = 128
85
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
86
+ SAFE_HEAD_DIM : tl.constexpr = True
87
+ USE_TMA : tl.constexpr = False
88
+ BLOCK_M : tl.constexpr = 128
89
+ BLOCK_N : tl.constexpr = 64
90
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
91
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
92
+ INDEX_DTYPE : tl.constexpr = tl.int32
93
+ Q = arg_Q
94
+ K = arg_K
95
+ V = arg_V
96
+ LSE = arg_LSE
97
+ MAX = arg_MAX
98
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
99
+ KV_IDX = arg_KV_IDX
100
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
101
+ FULL_KV_IDX = arg_FULL_KV_IDX
102
+
103
+ # Sub notation for this kernel:
104
+ #
105
+ # Q: Query, K: Key, V: Value
106
+ # M: Number of queries, N: Number of keys/values, D: Model dimension
107
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
108
+ # V_HEAD_DIM: The dimension of the value embeddings
109
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
110
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
111
+ #
112
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
113
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
114
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
115
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
116
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
117
+ #
118
+ # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
119
+ #
120
+ # (Modifiable) Performance tuning options
121
+ # BLOCK_M: The thread block size across the seqlen dim of Q.
122
+ # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
123
+
124
+ # The below are kernel options that can be applied for certain score_mods,
125
+ # or involve a numerics vs. perf tradeoff
126
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
127
+ # about 20% more numerical error, but slightly faster.
128
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
129
+ # is not masked out? If so, we can skip an extra safety check
130
+ # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
131
+ # contiguous? If so, we don't need to do an indirect jump for every block
132
+
133
+ tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
134
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
135
+
136
+ # Define strides of inputs
137
+ stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1
138
+ stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1
139
+ stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1
140
+
141
+ ZQ = 1
142
+ HQ = 32
143
+ Q_LEN = ks0
144
+ ZKV = 1
145
+ KV_LEN = ks1
146
+
147
+ MATMUL_PRECISION = Q.dtype.element_ty
148
+
149
+ q_start = tl.program_id(0).to(INDEX_DTYPE)
150
+ off_zq = tl.program_id(1).to(INDEX_DTYPE)
151
+ off_hq = tl.program_id(2).to(INDEX_DTYPE)
152
+
153
+ # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
154
+ # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
155
+ off_zkv = off_zq % ZKV
156
+ off_hkv = off_hq // GQA_SHARED_HEADS
157
+ off_g = off_hq % GQA_SHARED_HEADS
158
+
159
+ q_offset = off_zq * stride_qz + off_hq * stride_qh
160
+ k_offset = off_zkv * stride_kz + off_hkv * stride_kh
161
+ v_offset = off_zkv * stride_vz + off_hkv * stride_vh
162
+
163
+ Q = Q + q_offset
164
+ K = K + k_offset
165
+ V = V + v_offset
166
+
167
+ # Setting up the TMA descriptors for Q, K, V
168
+ desc_q = None
169
+ desc_k = None
170
+ desc_v = None
171
+
172
+ SPARSE_Z = 1
173
+ SPARSE_HQ = 1
174
+
175
+ sparse_idx_z = off_zq % SPARSE_Z
176
+ sparse_idx_hq = off_hq % SPARSE_HQ
177
+
178
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
179
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
180
+
181
+ stride_kv_num_blks_h = 5
182
+ stride_kv_idx_h = 25
183
+ stride_kv_idx_m = 5
184
+
185
+ # initialize pointer to m and l
186
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
187
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
188
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
189
+
190
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
191
+
192
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
193
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
194
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
195
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
196
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
197
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
198
+ q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
199
+
200
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
201
+ # We don't know anything "special" about these blocks, so we need to apply
202
+ # both score_mod and mask_mod to it
203
+ kv_indices = KV_IDX + sparse_kv_idx_offset
204
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
205
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
206
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
207
+
208
+
209
+ # K and V pointers will be passed directly to forward_inner
210
+
211
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
212
+
213
+
214
+ acc, l_i, m_i = forward_inner(
215
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
216
+ q, K, V,
217
+ desc_k, desc_v, Q_LEN, KV_LEN,
218
+ acc, l_i, m_i,
219
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
220
+ kv_start,
221
+ kv_indices, kv_num_blocks,
222
+ 0, block_n_end,
223
+ MATMUL_PRECISION,
224
+ stride_kk, stride_kn, stride_vn, stride_vk,
225
+ IS_FULL_BLOCKS=False,
226
+ )
227
+
228
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
229
+ # We know these blocks are guaranteed to be "full", so we don't need to
230
+ # apply mask_mod to them - only score_mod
231
+ if HAS_FULL_BLOCKS:
232
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
233
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
234
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
235
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
236
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
237
+ # K and V pointers will be passed directly to forward_inner
238
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
239
+
240
+ acc, l_i, m_i = forward_inner(
241
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
242
+ q, K, V,
243
+ desc_k, desc_v, Q_LEN, KV_LEN,
244
+ acc, l_i, m_i,
245
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
246
+ kv_start,
247
+ kv_indices, kv_num_blocks,
248
+ 0, block_n_end,
249
+ MATMUL_PRECISION,
250
+ stride_kk, stride_kn, stride_vn, stride_vk,
251
+ IS_FULL_BLOCKS=True,
252
+ )
253
+
254
+
255
+ # [Note] Handle fully masked out rows:
256
+ # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
257
+ # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
258
+ l_i = tl.where(l_i == 0.0, 1, l_i)
259
+
260
+ acc = acc / l_i[:, None]
261
+ idx_zq = tl.program_id(1).to(INDEX_DTYPE)
262
+ idx_hq = tl.program_id(2).to(INDEX_DTYPE)
263
+ idx_m = offs_m[:, None].to(INDEX_DTYPE)
264
+ idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
265
+
266
+ mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
267
+
268
+ tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
269
+ xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0
270
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask)
271
+
272
+ if OUTPUT_LOGSUMEXP:
273
+ off_hz = off_zq * HQ + off_hq
274
+ l_ptrs = LSE + off_hz * Q_LEN + offs_m
275
+ lse = m_i + tl.math.log2(l_i)
276
+ if IS_DIVISIBLE:
277
+ tl.store(l_ptrs, lse)
278
+ else:
279
+ tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
280
+
281
+ if OUTPUT_MAX:
282
+ off_hz = off_zq * HQ + off_hq
283
+ max_ptrs = MAX + off_hz * Q_LEN + offs_m
284
+ if IS_DIVISIBLE:
285
+ tl.store(max_ptrs, m_i)
286
+ else:
287
+ tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
288
+
289
+
290
+ # Utility triton funcs
291
+ @triton.jit
292
+ def get_offset_for_next_block(
293
+ loop_iter, col_indices, total_blocks,
294
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
295
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
296
+ ):
297
+ if BLOCKS_ARE_CONTIGUOUS:
298
+ return BLOCK
299
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
300
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
301
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
302
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
303
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
304
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
305
+ return offset
306
+
307
+ @triton.jit
308
+ def get_bounded_indices(indices, max_len=None):
309
+ return indices % max_len if max_len is not None else indices
310
+
311
+ @triton.jit
312
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
313
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
314
+ return tl.load(block_ptr)
315
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
316
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
317
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
318
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
319
+ else:
320
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
321
+
322
+ @triton.jit
323
+ def load_checked_2d(
324
+ ptr,
325
+ offs_m,
326
+ offs_n,
327
+ stride_m,
328
+ stride_n,
329
+ IS_DIVISIBLE_M: tl.constexpr,
330
+ IS_DIVISIBLE_N: tl.constexpr,
331
+ M_LEN: tl.constexpr,
332
+ N_LEN: tl.constexpr,
333
+ ):
334
+ # Calculate final pointer if strides are provided
335
+ if stride_m is not None and stride_n is not None:
336
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
337
+
338
+ # Handle all masking cases
339
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
340
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
341
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
342
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
343
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
344
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
345
+ else: # Both divisible
346
+ return tl.load(ptr)
347
+
348
+
349
+ # Common Imports
350
+ @triton.jit
351
+ def forward_block_mn(
352
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
353
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
354
+ # accumulated values
355
+ acc, l_i, m_i,
356
+ # Offsets
357
+ off_z, off_h, offs_m, offs_n,
358
+ # Offsets needed for TMA loads
359
+ kv_start,
360
+ kv_offset,
361
+ MATMUL_PRECISION, RCP_LN2,
362
+ # Strides for K and V
363
+ stride_kk, stride_kn, stride_vn, stride_vk,
364
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
365
+
366
+ ):
367
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
368
+ PRESCALE_QK : tl.constexpr = False
369
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
370
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
371
+ WRITE_DQ : tl.constexpr = True
372
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
373
+ OUTPUT_MAX : tl.constexpr = False
374
+ FLOAT32_PRECISION : tl.constexpr = 'ieee'
375
+ IS_DIVISIBLE : tl.constexpr = False
376
+ SM_SCALE : tl.constexpr = 0.08838834764831845
377
+ GQA_SHARED_HEADS : tl.constexpr = 4
378
+ HAS_FULL_BLOCKS : tl.constexpr = True
379
+ QK_HEAD_DIM : tl.constexpr = 128
380
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
381
+ V_HEAD_DIM : tl.constexpr = 128
382
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
383
+ SAFE_HEAD_DIM : tl.constexpr = True
384
+ USE_TMA : tl.constexpr = False
385
+ BLOCK_M : tl.constexpr = 128
386
+ BLOCK_N : tl.constexpr = 64
387
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
388
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
389
+ INDEX_DTYPE : tl.constexpr = tl.int32
390
+
391
+
392
+ # -- load k --
393
+ # NB reversed order to since K is transposed
394
+ kv_base_offset = kv_start + kv_offset
395
+
396
+ # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
397
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
398
+ offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
399
+ k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
400
+
401
+ k = tl.trans(k)
402
+ # -- compute qk ---
403
+ qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
404
+ if not PRESCALE_QK:
405
+ qk *= SM_SCALE
406
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
407
+ # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
408
+ # which is larger than the actual number of elements. To avoid access memory out of bound,
409
+ # we need to mask out the elements that are out of Q_LEN & KV_LEN.
410
+ m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
411
+ n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
412
+
413
+ tmp0 = (qk)
414
+ post_mod_scores = tmp0
415
+
416
+
417
+ if CHECK_BLOCK_BOUNDARY:
418
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
419
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
420
+
421
+ if not IS_FULL_BLOCKS:
422
+ tmp1 = (m)
423
+ tmp2 = tl.full([1], 0, tl.int32)
424
+ tmp3 = tmp1 < tmp2
425
+ tmp4 = (n)
426
+ tmp5 = tmp4 <= tmp1
427
+ tmp6 = tmp3 & tmp5
428
+ tmp7 = tmp1 >= tmp2
429
+ tmp8 = tmp4 < tmp2
430
+ tmp9 = tmp7 & tmp8
431
+ tmp10 = tmp8 == 0
432
+ tmp11 = tmp7 & tmp10
433
+ tmp12 = tmp1 - tmp2
434
+ tmp13 = tl.full([1], 16, tl.int32)
435
+ tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
436
+ tmp15 = tmp4 - tmp2
437
+ tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
438
+ tmp17 = tmp14 == tmp16
439
+ tmp18 = tmp11 & tmp17
440
+ tmp19 = tmp9 | tmp18
441
+ tmp20 = tmp6 | tmp19
442
+ mask_mod_output = tmp20
443
+
444
+
445
+ if CHECK_BLOCK_BOUNDARY:
446
+ mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
447
+ # apply mask for partially unmasked blocks
448
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
449
+
450
+ if not PRESCALE_QK:
451
+ post_mod_scores *= RCP_LN2
452
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
453
+
454
+ # -- compute scaling constant ---
455
+ m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
456
+ if not ROWS_GUARANTEED_SAFE:
457
+ masked_out_rows = (m_ij == float("-inf"))
458
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
459
+ else:
460
+ m_ij_masked = m_ij
461
+
462
+ alpha = tl.math.exp2(m_i - m_ij_masked)
463
+ p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
464
+
465
+ # NB: l_i update is pulled up here since it's a bit faster
466
+ # NB: For headdim=256, it's faster to move it back down to after m_i =
467
+ # m_ij
468
+ l_i = l_i * alpha + tl.sum(p, 1)
469
+ # # -- scale and update acc --
470
+ acc = acc * alpha[:, None]
471
+ # Calculate offsets for V loading - reuse kv_base_offset from K loading
472
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
473
+ v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
474
+ acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
475
+
476
+ # -- update m_i
477
+ m_i = m_ij
478
+
479
+ return acc, l_i, m_i
480
+
481
+ @triton.jit
482
+ def forward_inner(
483
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
484
+ q, K, V,
485
+ desc_k, desc_v, Q_LEN, KV_LEN,
486
+ # accumulated values
487
+ acc, l_i, m_i,
488
+ # Offsets used as inputs to score_mod & mask_mod
489
+ # of size [BLOCK_M, BLOCK_N] or scalar.
490
+ off_z, off_h, offs_m, offs_n,
491
+ # Offsets needed for TMA loads
492
+ kv_start,
493
+ # blocksparse data
494
+ kv_indices, kv_num_blocks,
495
+ # start kv and end kv block
496
+ block_n_start, block_n_end,
497
+ MATMUL_PRECISION,
498
+ # Strides for K and V
499
+ stride_kk, stride_kn, stride_vn, stride_vk,
500
+ IS_FULL_BLOCKS,
501
+ ):
502
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
503
+ PRESCALE_QK : tl.constexpr = False
504
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
505
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
506
+ WRITE_DQ : tl.constexpr = True
507
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
508
+ OUTPUT_MAX : tl.constexpr = False
509
+ FLOAT32_PRECISION : tl.constexpr = 'ieee'
510
+ IS_DIVISIBLE : tl.constexpr = False
511
+ SM_SCALE : tl.constexpr = 0.08838834764831845
512
+ GQA_SHARED_HEADS : tl.constexpr = 4
513
+ HAS_FULL_BLOCKS : tl.constexpr = True
514
+ QK_HEAD_DIM : tl.constexpr = 128
515
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
516
+ V_HEAD_DIM : tl.constexpr = 128
517
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
518
+ SAFE_HEAD_DIM : tl.constexpr = True
519
+ USE_TMA : tl.constexpr = False
520
+ BLOCK_M : tl.constexpr = 128
521
+ BLOCK_N : tl.constexpr = 64
522
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
523
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
524
+ INDEX_DTYPE : tl.constexpr = tl.int32
525
+
526
+
527
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
528
+ RCP_LN2: tl.constexpr = 1.44269504
529
+
530
+ if PRESCALE_QK:
531
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
532
+
533
+ kv_offset = 0
534
+
535
+ # loop over k, v and update accumulator until block_n_end
536
+ for start_n in range(block_n_start, block_n_end):
537
+ # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
538
+ if IS_DIVISIBLE:
539
+ acc, l_i, m_i = forward_block_mn(
540
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
541
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
542
+ # accumulated values
543
+ acc, l_i, m_i,
544
+ # Offsets
545
+ off_z, off_h, offs_m, offs_n,
546
+ # Offsets needed for TMA loads
547
+ kv_start,
548
+ kv_offset,
549
+ MATMUL_PRECISION, RCP_LN2,
550
+ # Strides for K and V
551
+ stride_kk, stride_kn, stride_vn, stride_vk,
552
+ IS_FULL_BLOCKS,
553
+ )
554
+ else:
555
+ # Benchmark shows even we applied mod & mask to each block for non divisible seqlen,
556
+ # it's on par or slightly faster than only applying to the last block in fwd.
557
+ # However, we choose different strategy for bwd, where we only apply mod & mask
558
+ # to the last block because it's faster a lot.
559
+ acc, l_i, m_i = forward_block_mn(
560
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
561
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
562
+ # accumulated values
563
+ acc, l_i, m_i,
564
+ # Offsets
565
+ off_z, off_h, offs_m, offs_n,
566
+ # Offsets needed for TMA loads
567
+ kv_start,
568
+ kv_offset,
569
+ MATMUL_PRECISION, RCP_LN2,
570
+ # Strides for K and V
571
+ stride_kk, stride_kn, stride_vn, stride_vk,
572
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
573
+ )
574
+
575
+
576
+
577
+ offset = get_offset_for_next_block(
578
+ start_n, kv_indices, kv_num_blocks,
579
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
580
+ )
581
+
582
+ offs_n = offs_n + offset
583
+ kv_offset += offset
584
+
585
+
586
+ return acc, l_i, m_i
587
+ ''', device_str='cuda')
588
+
589
+
590
+ # kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ft/cftmlennkcgyn4ynz7zxqohr2jlirziu3mfte3b4eg5y2466jcwm.py
591
+ # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul]
592
+ # Source node to ATen node mapping:
593
+ # lse_scaled => mul_9
594
+ # Graph fragment:
595
+ # %buf3 : Tensor = PlaceHolder[target=buf3]
596
+ # %mul_9 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {})
597
+ # return %mul_9
598
+ triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', '''
599
+ import triton
600
+ import triton.language as tl
601
+
602
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
603
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
604
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
605
+ triton_helpers.set_driver_to_gpu()
606
+
607
+ @triton_heuristics.pointwise(
608
+ size_hints={'x': 32768},
609
+ filename=__file__,
610
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
611
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
612
+ min_elem_per_thread=0
613
+ )
614
+ @triton.jit
615
+ def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
616
+ xoffset = tl.program_id(0) * XBLOCK
617
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
618
+ xmask = xindex < xnumel
619
+ x2 = xindex
620
+ x0 = (xindex % ks0)
621
+ x1 = triton_helpers.div_floor_integer(xindex, ks0)
622
+ tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last')
623
+ tmp1 = 0.6931471805599453
624
+ tmp2 = tmp0 * tmp1
625
+ tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask)
626
+ ''', device_str='cuda')
627
+
628
+
629
+ async_compile.wait(globals())
630
+ del async_compile
631
+
632
+ class Runner:
633
+ def __init__(self, partitions):
634
+ self.partitions = partitions
635
+
636
+ def recursively_apply_fns(self, fns):
637
+ new_callables = []
638
+ for fn, c in zip(fns, self.partitions):
639
+ new_callables.append(fn(c))
640
+ self.partitions = new_callables
641
+
642
+ def call(self, args):
643
+ arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1 = args
644
+ args.clear()
645
+ s50 = arg0_1
646
+ s0 = arg2_1
647
+ s43 = arg4_1
648
+ s37 = arg7_1
649
+ s71 = arg8_1
650
+ assert_size_stride(arg1_1, (1, 32, s37, 128), (4096*s37, 128, 4096, 1))
651
+ assert_size_stride(arg3_1, (1, 8, s0, 128), (1024*s0, 128, 1024, 1))
652
+ assert_size_stride(arg5_1, (1, 8, s43, 128), (1024*s43, 128, 1024, 1))
653
+ assert_size_stride(arg6_1, (1, 1, 5, 5), (25, 25, 5, 1))
654
+ assert_size_stride(arg9_1, (1, 1, 5), (5, 5, 1))
655
+ assert_size_stride(arg10_1, (1, 1, 5), (5, 5, 1))
656
+ assert_size_stride(arg11_1, (1, 1, 5, 5), (25, 25, 5, 1))
657
+ assert_size_stride(arg12_1, (1, 1, 5), (5, 5, 1))
658
+ assert_size_stride(arg13_1, (1, 1, 5, 5), (25, 25, 5, 1))
659
+ assert_size_stride(arg14_1, (1, 1, 5), (5, 5, 1))
660
+ assert_size_stride(arg15_1, (1, 1, 5, 5), (25, 25, 5, 1))
661
+ with torch.cuda._DeviceGuard(1):
662
+ torch.cuda.set_device(1)
663
+ buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32)
664
+ buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32)
665
+ buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16)
666
+ # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
667
+ stream1 = get_raw_stream(1)
668
+ triton_tem_fused_0.run(arg1_1, arg3_1, arg5_1, buf0, buf1, arg9_1, arg6_1, arg10_1, arg11_1, buf2, s37, s0, s43, (127 + s37) // 128, 1, 32, stream=stream1)
669
+ del arg10_1
670
+ del arg11_1
671
+ del arg1_1
672
+ del arg3_1
673
+ del arg5_1
674
+ del arg6_1
675
+ del arg9_1
676
+ del buf1
677
+ buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32)
678
+ # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul]
679
+ triton_poi_fused_mul_1_xnumel = 32*s37
680
+ stream1 = get_raw_stream(1)
681
+ triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream1)
682
+ del buf0
683
+ return (buf2, buf5, )
684
+
685
+ runner = Runner(partitions=[])
686
+ call = runner.call
687
+ recursively_apply_fns = runner.recursively_apply_fns
688
+
689
+
690
+ def benchmark_compiled_module(times=10, repeat=10):
691
+ from torch._dynamo.testing import rand_strided
692
+ from torch._inductor.utils import print_performance
693
+ arg0_1 = 528
694
+ arg1_1 = rand_strided((1, 32, 528, 128), (2162688, 128, 4096, 1), device='cuda:1', dtype=torch.bfloat16)
695
+ arg2_1 = 528
696
+ arg3_1 = rand_strided((1, 8, 528, 128), (540672, 128, 1024, 1), device='cuda:1', dtype=torch.bfloat16)
697
+ arg4_1 = 528
698
+ arg5_1 = rand_strided((1, 8, 528, 128), (540672, 128, 1024, 1), device='cuda:1', dtype=torch.bfloat16)
699
+ arg6_1 = rand_strided((1, 1, 5, 5), (25, 25, 5, 1), device='cuda:1', dtype=torch.int32)
700
+ arg7_1 = 528
701
+ arg8_1 = 528
702
+ arg9_1 = rand_strided((1, 1, 5), (5, 5, 1), device='cuda:1', dtype=torch.int32)
703
+ arg10_1 = rand_strided((1, 1, 5), (5, 5, 1), device='cuda:1', dtype=torch.int32)
704
+ arg11_1 = rand_strided((1, 1, 5, 5), (25, 25, 5, 1), device='cuda:1', dtype=torch.int32)
705
+ arg12_1 = rand_strided((1, 1, 5), (5, 5, 1), device='cuda:1', dtype=torch.int32)
706
+ arg13_1 = rand_strided((1, 1, 5, 5), (25, 25, 5, 1), device='cuda:1', dtype=torch.int32)
707
+ arg14_1 = rand_strided((1, 1, 5), (5, 5, 1), device='cuda:1', dtype=torch.int32)
708
+ arg15_1 = rand_strided((1, 1, 5, 5), (25, 25, 5, 1), device='cuda:1', dtype=torch.int32)
709
+ fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1])
710
+ return print_performance(fn, times=times, repeat=repeat)
711
+
712
+
713
+ if __name__ == "__main__":
714
+ from torch._inductor.wrapper_benchmark import compiled_module_main
715
+ compiled_module_main('None', benchmark_compiled_module)
progress/SpecForge/cache/compiled_kernels/3q/c3qbvcsx2w7qss2v3eocuadgz6t35joo33bflzqkxzzj747zcjpk.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 65536, 'r0_': 128},
12
+ reduction_hint=ReductionHint.DEFAULT,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ r0_numel = 128
20
+ rnumel = r0_numel
21
+ RBLOCK: tl.constexpr = R0_BLOCK
22
+ xoffset = tl.program_id(0) * XBLOCK
23
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
24
+ xmask = xindex < xnumel
25
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
26
+ rbase = r0_base
27
+ x0 = (xindex % ks0)
28
+ x1 = triton_helpers.div_floor_integer(xindex, ks0)
29
+ _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
30
+ x3 = xindex
31
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
32
+ r0_index = r0_offset + r0_base
33
+ r0_mask = r0_index < r0_numel
34
+ roffset = r0_offset
35
+ rindex = r0_index
36
+ r0_2 = r0_index
37
+ tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
38
+ tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
39
+ tmp2 = tmp0 * tmp1
40
+ tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
41
+ tmp5 = _tmp4 + tmp3
42
+ _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4)
43
+ tmp4 = tl.sum(_tmp4, 1)[:, None]
44
+ tmp7 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last')
45
+ tmp6 = tmp4.to(tl.float32)
46
+ tmp8 = 0.6931471805599453
47
+ tmp9 = tmp7 * tmp8
48
+ tmp10 = 1.4426950408889634
49
+ tmp11 = tmp9 * tmp10
50
+ tmp12 = tmp6 - tmp11
51
+ tl.store(out_ptr1 + (x3), tmp12, xmask)
progress/SpecForge/cache/compiled_kernels/3q/fc5920467dd1501963c976e2b895fc37747fdebfa098fff912209055f3a31828.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 64, "R0_BLOCK": 64, "num_warps": 16, "num_stages": 1, "configs_hash": "2685c2d349c32243d4ee216505dfdf1e257d04d8316595ed69d4ca3499146788", "found_by_coordesc": false, "time_taken_ms": 53, "triton_cache_hash": "GBIQTIXLLLI56EMJONBW74RZJ42E6PTSU5N7LA23N4VBEBKK3HNQ"}
progress/SpecForge/cache/compiled_kernels/3r/c3rkwwyedldrjz6sidtx5huqcsdgpdpu4xndmm6h4e4boo6cbg2w.py ADDED
@@ -0,0 +1,702 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['0_inference']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/nj/cnjtse3xftpnmqvwojj6g7ajl3r3hvxbz3sgyaaznnrxcs7gzj2e.py
38
+ # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
39
+ # Source node to ATen node mapping:
40
+ # flex_attention => flex_attention
41
+ # Graph fragment:
42
+ # %arg0_1 : Tensor "bf16[1, 32, 976, 128][3997696, 128, 4096, 1]cuda:7" = PlaceHolder[target=arg0_1]
43
+ # %arg1_1 : Tensor "bf16[1, 8, 976, 128][999424, 128, 1024, 1]cuda:7" = PlaceHolder[target=arg1_1]
44
+ # %arg2_1 : Tensor "bf16[1, 8, 976, 128][999424, 128, 1024, 1]cuda:7" = PlaceHolder[target=arg2_1]
45
+ # %getitem_1 : Tensor "f32[1, 32, 976][31232, 976, 1]cuda:7" = PlaceHolder[target=getitem_1]
46
+ # %buf1 : Tensor "f32[1, 32, 976][31232, 976, 1]cuda:7" = PlaceHolder[target=buf1]
47
+ # %arg3_1 : Tensor "i32[1, 1, 8][8, 8, 1]cuda:7" = PlaceHolder[target=arg3_1]
48
+ # %arg4_1 : Tensor "i32[1, 1, 8, 8][64, 64, 8, 1]cuda:7" = PlaceHolder[target=arg4_1]
49
+ # %arg5_1 : Tensor "i32[1, 1, 8][8, 8, 1]cuda:7" = PlaceHolder[target=arg5_1]
50
+ # %arg6_1 : Tensor "i32[1, 1, 8, 8][64, 64, 8, 1]cuda:7" = PlaceHolder[target=arg6_1]
51
+ # %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg0_1, %arg1_1, %arg2_1, %sdpa_score0, (976, 976, %arg3_1, %arg4_1, %arg5_1, %arg6_1, %arg7_1, %arg8_1, %arg9_1, %arg10_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
52
+ # return %getitem
53
+ triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', '''
54
+ import triton
55
+ import triton.language as tl
56
+
57
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
58
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
59
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
60
+
61
+ @triton_heuristics.template(
62
+
63
+ num_stages=3,
64
+ num_warps=8,
65
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]},
66
+ inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
67
+
68
+ )
69
+ @triton.jit
70
+ def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0):
71
+ PRESCALE_QK : tl.constexpr = False
72
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
73
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
74
+ WRITE_DQ : tl.constexpr = True
75
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
76
+ OUTPUT_MAX : tl.constexpr = False
77
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
78
+ IS_DIVISIBLE : tl.constexpr = False
79
+ SM_SCALE : tl.constexpr = 0.08838834764831845
80
+ GQA_SHARED_HEADS : tl.constexpr = 4
81
+ HAS_FULL_BLOCKS : tl.constexpr = True
82
+ QK_HEAD_DIM : tl.constexpr = 128
83
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
84
+ V_HEAD_DIM : tl.constexpr = 128
85
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
86
+ SAFE_HEAD_DIM : tl.constexpr = True
87
+ USE_TMA : tl.constexpr = False
88
+ BLOCK_M : tl.constexpr = 128
89
+ BLOCK_N : tl.constexpr = 64
90
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
91
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
92
+ INDEX_DTYPE : tl.constexpr = tl.int32
93
+ Q = arg_Q
94
+ K = arg_K
95
+ V = arg_V
96
+ LSE = arg_LSE
97
+ MAX = arg_MAX
98
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
99
+ KV_IDX = arg_KV_IDX
100
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
101
+ FULL_KV_IDX = arg_FULL_KV_IDX
102
+
103
+ # Sub notation for this kernel:
104
+ #
105
+ # Q: Query, K: Key, V: Value
106
+ # M: Number of queries, N: Number of keys/values, D: Model dimension
107
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
108
+ # V_HEAD_DIM: The dimension of the value embeddings
109
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
110
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
111
+ #
112
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
113
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
114
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
115
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
116
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
117
+ #
118
+ # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
119
+ #
120
+ # (Modifiable) Performance tuning options
121
+ # BLOCK_M: The thread block size across the seqlen dim of Q.
122
+ # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
123
+
124
+ # The below are kernel options that can be applied for certain score_mods,
125
+ # or involve a numerics vs. perf tradeoff
126
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
127
+ # about 20% more numerical error, but slightly faster.
128
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
129
+ # is not masked out? If so, we can skip an extra safety check
130
+ # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
131
+ # contiguous? If so, we don't need to do an indirect jump for every block
132
+
133
+ tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
134
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
135
+
136
+ # Define strides of inputs
137
+ stride_qz, stride_qh, stride_qm, stride_qk = 3997696, 128, 4096, 1
138
+ stride_kz, stride_kh, stride_kn, stride_kk = 999424, 128, 1024, 1
139
+ stride_vz, stride_vh, stride_vn, stride_vk = 999424, 128, 1024, 1
140
+
141
+ ZQ = 1
142
+ HQ = 32
143
+ Q_LEN = 976
144
+ ZKV = 1
145
+ KV_LEN = 976
146
+
147
+ MATMUL_PRECISION = Q.dtype.element_ty
148
+
149
+ q_start = tl.program_id(0).to(INDEX_DTYPE)
150
+ off_zq = tl.program_id(1).to(INDEX_DTYPE)
151
+ off_hq = tl.program_id(2).to(INDEX_DTYPE)
152
+
153
+ # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
154
+ # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
155
+ off_zkv = off_zq % ZKV
156
+ off_hkv = off_hq // GQA_SHARED_HEADS
157
+ off_g = off_hq % GQA_SHARED_HEADS
158
+
159
+ q_offset = off_zq * stride_qz + off_hq * stride_qh
160
+ k_offset = off_zkv * stride_kz + off_hkv * stride_kh
161
+ v_offset = off_zkv * stride_vz + off_hkv * stride_vh
162
+
163
+ Q = Q + q_offset
164
+ K = K + k_offset
165
+ V = V + v_offset
166
+
167
+ # Setting up the TMA descriptors for Q, K, V
168
+ desc_q = None
169
+ desc_k = None
170
+ desc_v = None
171
+
172
+ SPARSE_Z = 1
173
+ SPARSE_HQ = 1
174
+
175
+ sparse_idx_z = off_zq % SPARSE_Z
176
+ sparse_idx_hq = off_hq % SPARSE_HQ
177
+
178
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
179
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
180
+
181
+ stride_kv_num_blks_h = 8
182
+ stride_kv_idx_h = 64
183
+ stride_kv_idx_m = 8
184
+
185
+ # initialize pointer to m and l
186
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
187
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
188
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
189
+
190
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
191
+
192
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
193
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
194
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
195
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
196
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
197
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
198
+ q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
199
+
200
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
201
+ # We don't know anything "special" about these blocks, so we need to apply
202
+ # both score_mod and mask_mod to it
203
+ kv_indices = KV_IDX + sparse_kv_idx_offset
204
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
205
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
206
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
207
+
208
+
209
+ # K and V pointers will be passed directly to forward_inner
210
+
211
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
212
+
213
+
214
+ acc, l_i, m_i = forward_inner(
215
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0,
216
+ q, K, V,
217
+ desc_k, desc_v, Q_LEN, KV_LEN,
218
+ acc, l_i, m_i,
219
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
220
+ kv_start,
221
+ kv_indices, kv_num_blocks,
222
+ 0, block_n_end,
223
+ MATMUL_PRECISION,
224
+ stride_kk, stride_kn, stride_vn, stride_vk,
225
+ IS_FULL_BLOCKS=False,
226
+ )
227
+
228
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
229
+ # We know these blocks are guaranteed to be "full", so we don't need to
230
+ # apply mask_mod to them - only score_mod
231
+ if HAS_FULL_BLOCKS:
232
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
233
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
234
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
235
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
236
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
237
+ # K and V pointers will be passed directly to forward_inner
238
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
239
+
240
+ acc, l_i, m_i = forward_inner(
241
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0,
242
+ q, K, V,
243
+ desc_k, desc_v, Q_LEN, KV_LEN,
244
+ acc, l_i, m_i,
245
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
246
+ kv_start,
247
+ kv_indices, kv_num_blocks,
248
+ 0, block_n_end,
249
+ MATMUL_PRECISION,
250
+ stride_kk, stride_kn, stride_vn, stride_vk,
251
+ IS_FULL_BLOCKS=True,
252
+ )
253
+
254
+
255
+ # [Note] Handle fully masked out rows:
256
+ # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
257
+ # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
258
+ l_i = tl.where(l_i == 0.0, 1, l_i)
259
+
260
+ acc = acc / l_i[:, None]
261
+ idx_zq = tl.program_id(1).to(INDEX_DTYPE)
262
+ idx_hq = tl.program_id(2).to(INDEX_DTYPE)
263
+ idx_m = offs_m[:, None].to(INDEX_DTYPE)
264
+ idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
265
+
266
+ mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
267
+
268
+ tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
269
+ xindex = idx_d + 128*idx_m + 124928*idx_hq + 3997696*idx_zq
270
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask)
271
+
272
+ if OUTPUT_LOGSUMEXP:
273
+ off_hz = off_zq * HQ + off_hq
274
+ l_ptrs = LSE + off_hz * Q_LEN + offs_m
275
+ lse = m_i + tl.math.log2(l_i)
276
+ if IS_DIVISIBLE:
277
+ tl.store(l_ptrs, lse)
278
+ else:
279
+ tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
280
+
281
+ if OUTPUT_MAX:
282
+ off_hz = off_zq * HQ + off_hq
283
+ max_ptrs = MAX + off_hz * Q_LEN + offs_m
284
+ if IS_DIVISIBLE:
285
+ tl.store(max_ptrs, m_i)
286
+ else:
287
+ tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
288
+
289
+
290
+ # Utility triton funcs
291
+ @triton.jit
292
+ def get_offset_for_next_block(
293
+ loop_iter, col_indices, total_blocks,
294
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
295
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
296
+ ):
297
+ if BLOCKS_ARE_CONTIGUOUS:
298
+ return BLOCK
299
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
300
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
301
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
302
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
303
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
304
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
305
+ return offset
306
+
307
+ @triton.jit
308
+ def get_bounded_indices(indices, max_len=None):
309
+ return indices % max_len if max_len is not None else indices
310
+
311
+ @triton.jit
312
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
313
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
314
+ return tl.load(block_ptr)
315
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
316
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
317
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
318
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
319
+ else:
320
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
321
+
322
+ @triton.jit
323
+ def load_checked_2d(
324
+ ptr,
325
+ offs_m,
326
+ offs_n,
327
+ stride_m,
328
+ stride_n,
329
+ IS_DIVISIBLE_M: tl.constexpr,
330
+ IS_DIVISIBLE_N: tl.constexpr,
331
+ M_LEN: tl.constexpr,
332
+ N_LEN: tl.constexpr,
333
+ ):
334
+ # Calculate final pointer if strides are provided
335
+ if stride_m is not None and stride_n is not None:
336
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
337
+
338
+ # Handle all masking cases
339
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
340
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
341
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
342
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
343
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
344
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
345
+ else: # Both divisible
346
+ return tl.load(ptr)
347
+
348
+
349
+ # Common Imports
350
+ @triton.jit
351
+ def forward_block_mn(
352
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0,
353
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
354
+ # accumulated values
355
+ acc, l_i, m_i,
356
+ # Offsets
357
+ off_z, off_h, offs_m, offs_n,
358
+ # Offsets needed for TMA loads
359
+ kv_start,
360
+ kv_offset,
361
+ MATMUL_PRECISION, RCP_LN2,
362
+ # Strides for K and V
363
+ stride_kk, stride_kn, stride_vn, stride_vk,
364
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
365
+
366
+ ):
367
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
368
+ PRESCALE_QK : tl.constexpr = False
369
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
370
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
371
+ WRITE_DQ : tl.constexpr = True
372
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
373
+ OUTPUT_MAX : tl.constexpr = False
374
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
375
+ IS_DIVISIBLE : tl.constexpr = False
376
+ SM_SCALE : tl.constexpr = 0.08838834764831845
377
+ GQA_SHARED_HEADS : tl.constexpr = 4
378
+ HAS_FULL_BLOCKS : tl.constexpr = True
379
+ QK_HEAD_DIM : tl.constexpr = 128
380
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
381
+ V_HEAD_DIM : tl.constexpr = 128
382
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
383
+ SAFE_HEAD_DIM : tl.constexpr = True
384
+ USE_TMA : tl.constexpr = False
385
+ BLOCK_M : tl.constexpr = 128
386
+ BLOCK_N : tl.constexpr = 64
387
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
388
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
389
+ INDEX_DTYPE : tl.constexpr = tl.int32
390
+
391
+
392
+ # -- load k --
393
+ # NB reversed order to since K is transposed
394
+ kv_base_offset = kv_start + kv_offset
395
+
396
+ # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
397
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
398
+ offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
399
+ k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
400
+
401
+ k = tl.trans(k)
402
+ # -- compute qk ---
403
+ qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
404
+ if not PRESCALE_QK:
405
+ qk *= SM_SCALE
406
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
407
+ # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
408
+ # which is larger than the actual number of elements. To avoid access memory out of bound,
409
+ # we need to mask out the elements that are out of Q_LEN & KV_LEN.
410
+ m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
411
+ n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
412
+
413
+ tmp0 = (qk)
414
+ post_mod_scores = tmp0
415
+
416
+
417
+ if CHECK_BLOCK_BOUNDARY:
418
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
419
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
420
+
421
+ if not IS_FULL_BLOCKS:
422
+ tmp1 = (m)
423
+ tmp2 = tl.full([1], 0, tl.int32)
424
+ tmp3 = tmp1 < tmp2
425
+ tmp4 = (n)
426
+ tmp5 = tmp4 <= tmp1
427
+ tmp6 = tmp3 & tmp5
428
+ tmp7 = tmp1 >= tmp2
429
+ tmp8 = tmp4 < tmp2
430
+ tmp9 = tmp7 & tmp8
431
+ tmp10 = tmp8 == 0
432
+ tmp11 = tmp7 & tmp10
433
+ tmp12 = tmp1 - tmp2
434
+ tmp13 = tl.full([1], 16, tl.int32)
435
+ tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
436
+ tmp15 = tmp4 - tmp2
437
+ tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
438
+ tmp17 = tmp14 == tmp16
439
+ tmp18 = tmp11 & tmp17
440
+ tmp19 = tmp9 | tmp18
441
+ tmp20 = tmp6 | tmp19
442
+ mask_mod_output = tmp20
443
+
444
+
445
+ if CHECK_BLOCK_BOUNDARY:
446
+ mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
447
+ # apply mask for partially unmasked blocks
448
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
449
+
450
+ if not PRESCALE_QK:
451
+ post_mod_scores *= RCP_LN2
452
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
453
+
454
+ # -- compute scaling constant ---
455
+ m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
456
+ if not ROWS_GUARANTEED_SAFE:
457
+ masked_out_rows = (m_ij == float("-inf"))
458
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
459
+ else:
460
+ m_ij_masked = m_ij
461
+
462
+ alpha = tl.math.exp2(m_i - m_ij_masked)
463
+ p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
464
+
465
+ # NB: l_i update is pulled up here since it's a bit faster
466
+ # NB: For headdim=256, it's faster to move it back down to after m_i =
467
+ # m_ij
468
+ l_i = l_i * alpha + tl.sum(p, 1)
469
+ # # -- scale and update acc --
470
+ acc = acc * alpha[:, None]
471
+ # Calculate offsets for V loading - reuse kv_base_offset from K loading
472
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
473
+ v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
474
+ acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
475
+
476
+ # -- update m_i
477
+ m_i = m_ij
478
+
479
+ return acc, l_i, m_i
480
+
481
+ @triton.jit
482
+ def forward_inner(
483
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0,
484
+ q, K, V,
485
+ desc_k, desc_v, Q_LEN, KV_LEN,
486
+ # accumulated values
487
+ acc, l_i, m_i,
488
+ # Offsets used as inputs to score_mod & mask_mod
489
+ # of size [BLOCK_M, BLOCK_N] or scalar.
490
+ off_z, off_h, offs_m, offs_n,
491
+ # Offsets needed for TMA loads
492
+ kv_start,
493
+ # blocksparse data
494
+ kv_indices, kv_num_blocks,
495
+ # start kv and end kv block
496
+ block_n_start, block_n_end,
497
+ MATMUL_PRECISION,
498
+ # Strides for K and V
499
+ stride_kk, stride_kn, stride_vn, stride_vk,
500
+ IS_FULL_BLOCKS,
501
+ ):
502
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
503
+ PRESCALE_QK : tl.constexpr = False
504
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
505
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
506
+ WRITE_DQ : tl.constexpr = True
507
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
508
+ OUTPUT_MAX : tl.constexpr = False
509
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
510
+ IS_DIVISIBLE : tl.constexpr = False
511
+ SM_SCALE : tl.constexpr = 0.08838834764831845
512
+ GQA_SHARED_HEADS : tl.constexpr = 4
513
+ HAS_FULL_BLOCKS : tl.constexpr = True
514
+ QK_HEAD_DIM : tl.constexpr = 128
515
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
516
+ V_HEAD_DIM : tl.constexpr = 128
517
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
518
+ SAFE_HEAD_DIM : tl.constexpr = True
519
+ USE_TMA : tl.constexpr = False
520
+ BLOCK_M : tl.constexpr = 128
521
+ BLOCK_N : tl.constexpr = 64
522
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
523
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
524
+ INDEX_DTYPE : tl.constexpr = tl.int32
525
+
526
+
527
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
528
+ RCP_LN2: tl.constexpr = 1.44269504
529
+
530
+ if PRESCALE_QK:
531
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
532
+
533
+ kv_offset = 0
534
+
535
+ # loop over k, v and update accumulator until block_n_end
536
+ for start_n in range(block_n_start, block_n_end):
537
+ # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
538
+ if IS_DIVISIBLE:
539
+ acc, l_i, m_i = forward_block_mn(
540
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0,
541
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
542
+ # accumulated values
543
+ acc, l_i, m_i,
544
+ # Offsets
545
+ off_z, off_h, offs_m, offs_n,
546
+ # Offsets needed for TMA loads
547
+ kv_start,
548
+ kv_offset,
549
+ MATMUL_PRECISION, RCP_LN2,
550
+ # Strides for K and V
551
+ stride_kk, stride_kn, stride_vn, stride_vk,
552
+ IS_FULL_BLOCKS,
553
+ )
554
+ else:
555
+ # Benchmark shows even we applied mod & mask to each block for non divisible seqlen,
556
+ # it's on par or slightly faster than only applying to the last block in fwd.
557
+ # However, we choose different strategy for bwd, where we only apply mod & mask
558
+ # to the last block because it's faster a lot.
559
+ acc, l_i, m_i = forward_block_mn(
560
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0,
561
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
562
+ # accumulated values
563
+ acc, l_i, m_i,
564
+ # Offsets
565
+ off_z, off_h, offs_m, offs_n,
566
+ # Offsets needed for TMA loads
567
+ kv_start,
568
+ kv_offset,
569
+ MATMUL_PRECISION, RCP_LN2,
570
+ # Strides for K and V
571
+ stride_kk, stride_kn, stride_vn, stride_vk,
572
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
573
+ )
574
+
575
+
576
+
577
+ offset = get_offset_for_next_block(
578
+ start_n, kv_indices, kv_num_blocks,
579
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
580
+ )
581
+
582
+ offs_n = offs_n + offset
583
+ kv_offset += offset
584
+
585
+
586
+ return acc, l_i, m_i
587
+ ''', device_str='cuda')
588
+
589
+
590
+ # kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tf/ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py
591
+ # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul]
592
+ # Source node to ATen node mapping:
593
+ # lse_scaled => mul
594
+ # Graph fragment:
595
+ # %buf3 : Tensor = PlaceHolder[target=buf3]
596
+ # %mul : Tensor "f32[1, 32, 976][31232, 976, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {})
597
+ # return %mul
598
+ triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', '''
599
+ import triton
600
+ import triton.language as tl
601
+
602
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
603
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
604
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
605
+ triton_helpers.set_driver_to_gpu()
606
+
607
+ @triton_heuristics.pointwise(
608
+ size_hints={'x': 32768},
609
+ filename=__file__,
610
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
611
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 249856}},
612
+ min_elem_per_thread=0
613
+ )
614
+ @triton.jit
615
+ def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
616
+ xnumel = 31232
617
+ xoffset = tl.program_id(0) * XBLOCK
618
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
619
+ xmask = xindex < xnumel
620
+ x0 = xindex
621
+ tmp0 = tl.load(in_ptr0 + (x0), xmask)
622
+ tmp1 = 0.6931471805599453
623
+ tmp2 = tmp0 * tmp1
624
+ tl.store(out_ptr0 + (x0), tmp2, xmask)
625
+ ''', device_str='cuda')
626
+
627
+
628
+ async_compile.wait(globals())
629
+ del async_compile
630
+
631
+ class Runner:
632
+ def __init__(self, partitions):
633
+ self.partitions = partitions
634
+
635
+ def recursively_apply_fns(self, fns):
636
+ new_callables = []
637
+ for fn, c in zip(fns, self.partitions):
638
+ new_callables.append(fn(c))
639
+ self.partitions = new_callables
640
+
641
+ def call(self, args):
642
+ arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1 = args
643
+ args.clear()
644
+ assert_size_stride(arg0_1, (1, 32, 976, 128), (3997696, 128, 4096, 1))
645
+ assert_size_stride(arg1_1, (1, 8, 976, 128), (999424, 128, 1024, 1))
646
+ assert_size_stride(arg2_1, (1, 8, 976, 128), (999424, 128, 1024, 1))
647
+ assert_size_stride(arg3_1, (1, 1, 8), (8, 8, 1))
648
+ assert_size_stride(arg4_1, (1, 1, 8, 8), (64, 64, 8, 1))
649
+ assert_size_stride(arg5_1, (1, 1, 8), (8, 8, 1))
650
+ assert_size_stride(arg6_1, (1, 1, 8, 8), (64, 64, 8, 1))
651
+ assert_size_stride(arg7_1, (1, 1, 8), (8, 8, 1))
652
+ assert_size_stride(arg8_1, (1, 1, 8, 8), (64, 64, 8, 1))
653
+ assert_size_stride(arg9_1, (1, 1, 8), (8, 8, 1))
654
+ assert_size_stride(arg10_1, (1, 1, 8, 8), (64, 64, 8, 1))
655
+ with torch.cuda._DeviceGuard(7):
656
+ torch.cuda.set_device(7)
657
+ buf0 = empty_strided_cuda((1, 32, 976), (31232, 976, 1), torch.float32)
658
+ buf1 = empty_strided_cuda((1, 32, 976), (31232, 976, 1), torch.float32)
659
+ buf2 = empty_strided_cuda((1, 32, 976, 128), (3997696, 128, 4096, 1), torch.bfloat16)
660
+ # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
661
+ stream7 = get_raw_stream(7)
662
+ triton_tem_fused_0.run(arg0_1, arg1_1, arg2_1, buf0, buf1, arg3_1, arg4_1, arg5_1, arg6_1, buf2, 8, 1, 32, stream=stream7)
663
+ del arg0_1
664
+ del arg1_1
665
+ del arg2_1
666
+ del arg3_1
667
+ del arg4_1
668
+ del arg5_1
669
+ del arg6_1
670
+ buf5 = buf1; del buf1 # reuse
671
+ # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul]
672
+ stream7 = get_raw_stream(7)
673
+ triton_poi_fused_mul_1.run(buf0, buf5, 31232, stream=stream7)
674
+ del buf0
675
+ return (buf2, buf5, )
676
+
677
+ runner = Runner(partitions=[])
678
+ call = runner.call
679
+ recursively_apply_fns = runner.recursively_apply_fns
680
+
681
+
682
+ def benchmark_compiled_module(times=10, repeat=10):
683
+ from torch._dynamo.testing import rand_strided
684
+ from torch._inductor.utils import print_performance
685
+ arg0_1 = rand_strided((1, 32, 976, 128), (3997696, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16)
686
+ arg1_1 = rand_strided((1, 8, 976, 128), (999424, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16)
687
+ arg2_1 = rand_strided((1, 8, 976, 128), (999424, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16)
688
+ arg3_1 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:7', dtype=torch.int32)
689
+ arg4_1 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:7', dtype=torch.int32)
690
+ arg5_1 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:7', dtype=torch.int32)
691
+ arg6_1 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:7', dtype=torch.int32)
692
+ arg7_1 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:7', dtype=torch.int32)
693
+ arg8_1 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:7', dtype=torch.int32)
694
+ arg9_1 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:7', dtype=torch.int32)
695
+ arg10_1 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:7', dtype=torch.int32)
696
+ fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1])
697
+ return print_performance(fn, times=times, repeat=repeat)
698
+
699
+
700
+ if __name__ == "__main__":
701
+ from torch._inductor.wrapper_benchmark import compiled_module_main
702
+ compiled_module_main('None', benchmark_compiled_module)
progress/SpecForge/cache/compiled_kernels/3z/c3zi2pt6zmbthc6ythgt5p4ednhp6m24gpscb2pt6adf6xojetua.py ADDED
@@ -0,0 +1,799 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = False
27
+ SM_SCALE : tl.constexpr = 0.08838834764831845
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ BLOCK_M1 : tl.constexpr = 64
36
+ BLOCK_N1 : tl.constexpr = 128
37
+ BLOCK_M2 : tl.constexpr = 128
38
+ BLOCK_N2 : tl.constexpr = 64
39
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
40
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
41
+ INDEX_DTYPE : tl.constexpr = tl.int32
42
+ Q = arg_Q
43
+ K = arg_K
44
+ V = arg_V
45
+ LSE = arg_LSE
46
+ DELTA = arg_DELTA
47
+ DO = arg_DO
48
+ DQ = arg_DQ
49
+ DV = arg_DV
50
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
51
+ KV_IDX = arg_KV_IDX
52
+ Q_NUM_BLKS = arg_Q_NUM_BLKS
53
+ Q_IDX = arg_Q_IDX
54
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
55
+ FULL_KV_IDX = arg_FULL_KV_IDX
56
+ FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
57
+ FULL_Q_IDX = arg_FULL_Q_IDX
58
+
59
+ # Sub notation for this kernel:
60
+ #
61
+ # Q: Query, K: Key, V: Value
62
+ # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
63
+ # DELTA: Precomputed sum(OUT*DO, axis=-1)
64
+ # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
65
+ # DK: Derivative of Key, is the written to via the store_output call due to some limitations with
66
+ # inductor codegen
67
+ # M: Number of queries, N: Number of keys/values
68
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
69
+ # V_HEAD_DIM: The dimension of the value embeddings
70
+ # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
71
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
72
+ # (Modifiable) Performance tuning options
73
+ # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
74
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
75
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
76
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
77
+ #
78
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
79
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
80
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
81
+ # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
82
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
83
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
84
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
85
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
86
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
87
+
88
+ # The below are kernel options that can be applied for certain score_mods,
89
+ # or involve a numerics vs. perf tradeoff
90
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
91
+ # about 20% more numerical error, but slightly faster.
92
+
93
+ # Define strides of inputs
94
+ stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1
95
+ stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1
96
+ stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1
97
+ stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1
98
+
99
+ stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1
100
+ stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1
101
+
102
+ ZQ = 1
103
+ HQ = 32
104
+ HKV = 8
105
+ Q_LEN = ks0
106
+ ZKV = 1
107
+ KV_LEN = ks1
108
+
109
+ MATMUL_PRECISION = Q.dtype.element_ty
110
+
111
+ pid = tl.program_id(0).to(INDEX_DTYPE)
112
+ NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
113
+ NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
114
+
115
+ off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
116
+ off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
117
+ off_zkv = off_zq % ZKV # kv batch idx
118
+
119
+ SPARSE_Z = 1
120
+ SPARSE_HQ = 1
121
+
122
+ sparse_idx_z = off_zq % SPARSE_Z
123
+
124
+ k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
125
+ v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
126
+ # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
127
+ # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
128
+ dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
129
+
130
+ # offset K, V, DV pointers for batch/kv-head
131
+ K += k_adj
132
+ V += v_adj
133
+ DV += dv_adj
134
+
135
+ RCP_LN2 = 1.44269504
136
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
137
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
138
+
139
+ if pid >= NUM_KV_BLOCKS:
140
+ off_pid = pid - NUM_KV_BLOCKS
141
+ # THIS BLOCK DOES DQ
142
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
143
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
144
+ off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
145
+ start_m2_block = off_pid % NUM_Q_BLOCKS
146
+ off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
147
+ stride_kv_num_blks_h = 1
148
+ stride_kv_idx_h = 1
149
+ stride_kv_idx_m = 1
150
+
151
+ sparse_idx_hq2 = off_hq2 % SPARSE_HQ
152
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
153
+
154
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
155
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
156
+
157
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
158
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
159
+ do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
160
+ dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
161
+ off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
162
+
163
+ Q2 = Q + q_adj2
164
+ DO2 = DO + do_adj2
165
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
166
+ # if Q is broadcasted)
167
+ DQ2 = DQ + dq_adj2
168
+ LSE2 = LSE + off_chz2
169
+ DELTA2 = DELTA + off_chz2
170
+
171
+ # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
172
+ dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
173
+
174
+ start_m2 = start_m2_block * BLOCK_M2
175
+ offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
176
+
177
+ # load Q and do: they stay in SRAM throughout the inner loop.
178
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
179
+ do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
180
+
181
+ if PRESCALE_QK:
182
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
183
+
184
+ if IS_DIVISIBLE:
185
+ Di = tl.load(DELTA2 + offs_m2)
186
+ lse = tl.load(LSE2 + offs_m2)
187
+ else:
188
+ Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
189
+ lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
190
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
191
+ lse = lse[:, None]
192
+
193
+ # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
194
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
195
+ kv_indices = KV_IDX + sparse_kv_idx_offset
196
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
197
+ sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
198
+
199
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
200
+ dq = bwd_dq_inner(
201
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
202
+ K, V,
203
+ dq, q, do, Di, lse,
204
+ off_zq, off_hq2, offs_m2, offs_n2,
205
+ stride_kn, stride_kd, stride_vn, stride_vd,
206
+ kv_indices, sparse_kv_num_blocks,
207
+ MATMUL_PRECISION,
208
+ IS_FULL_BLOCKS=False,
209
+ )
210
+
211
+ if HAS_FULL_BLOCKS:
212
+ # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
213
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
214
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
215
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
216
+ sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
217
+
218
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
219
+ dq = bwd_dq_inner(
220
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
221
+ K, V,
222
+ dq, q, do, Di, lse,
223
+ off_zq, off_hq2, offs_m2, offs_n2,
224
+ stride_kn, stride_kd, stride_vn, stride_vd,
225
+ kv_indices, sparse_kv_num_blocks,
226
+ MATMUL_PRECISION,
227
+ IS_FULL_BLOCKS=True,
228
+ )
229
+
230
+ # Write back dQ.
231
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
232
+ dq *= SM_SCALE
233
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
234
+ tl.store(dq_ptrs, dq)
235
+ else:
236
+ tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
237
+ else:
238
+ # THIS BLOCK DOES DK & DV
239
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
240
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
241
+
242
+ pid_mask = pid // SPARSE_KV_MULTIPLE
243
+
244
+ stride_q_num_blks_h = 1
245
+ stride_q_idx_h = 1
246
+ stride_q_idx_n = 1
247
+
248
+
249
+ dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
250
+ dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
251
+
252
+ start_n1 = pid * BLOCK_N1
253
+ offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
254
+
255
+ # load K and V: they stay in SRAM throughout the inner loop.
256
+ k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
257
+ v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
258
+
259
+ if PRESCALE_QK:
260
+ k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
261
+
262
+ for off_g in range(0, GQA_SHARED_HEADS):
263
+ off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
264
+
265
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
266
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
267
+ do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
268
+ dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
269
+ off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
270
+
271
+ Q1 = Q + q_adj1
272
+ DO1 = DO + do_adj1
273
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
274
+ # if Q is broadcasted)
275
+ LSE1 = LSE + off_chz1
276
+ DELTA1 = DELTA + off_chz1
277
+
278
+ sparse_idx_hq1 = off_hq1 % SPARSE_HQ
279
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
280
+
281
+ sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
282
+ sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
283
+
284
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
285
+ # Q_IDX and Q_NUM_BLKS are always contiguous.
286
+ q_indices = Q_IDX + sparse_q_idx_offset
287
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
288
+ sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
289
+
290
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
291
+ dk, dv = bwd_dkdv_inner(
292
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
293
+ Q1, DO1, DELTA1, LSE1,
294
+ dk, dv, k, v,
295
+ off_zq, off_hq1, offs_n1, offs_m1,
296
+ stride_qm, stride_qd, stride_dom, stride_dod,
297
+ q_indices, sparse_q_num_blocks,
298
+ MATMUL_PRECISION,
299
+ IS_FULL_BLOCKS=False,
300
+ )
301
+
302
+
303
+ if HAS_FULL_BLOCKS:
304
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
305
+ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
306
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset
307
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
308
+ sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
309
+
310
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
311
+ dk, dv = bwd_dkdv_inner(
312
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
313
+ Q1, DO1, DELTA1, LSE1,
314
+ dk, dv, k, v,
315
+ off_zq, off_hq1, offs_n1, offs_m1,
316
+ stride_qm, stride_qd, stride_dom, stride_dod,
317
+ q_indices, sparse_q_num_blocks,
318
+ MATMUL_PRECISION,
319
+ IS_FULL_BLOCKS=True,
320
+ )
321
+
322
+ # Write back dV and dK.
323
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
324
+
325
+ index_n = offs_n1[:, None]
326
+ index_k = offs_k[None, :]
327
+ index_v = offs_v[None, :]
328
+
329
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
330
+ tl.store(dv_ptrs, dv)
331
+ else:
332
+ tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
333
+
334
+ dk *= SM_SCALE
335
+
336
+ if SAFE_HEAD_DIM:
337
+ mask = index_n < KV_LEN
338
+ else:
339
+ mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
340
+
341
+ # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
342
+ # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
343
+ tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
344
+ xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1
345
+ tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask)
346
+
347
+ @triton.jit
348
+ def bwd_dq_inner(
349
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
350
+ K, V, # pointers
351
+ dq, q, do, Di, lse,
352
+ off_z, off_hq, offs_m2, offs_n2,
353
+ stride_kn, stride_kd, stride_vn, stride_vd,
354
+ kv_indices, sparse_kv_num_blocks,
355
+ MATMUL_PRECISION,
356
+ IS_FULL_BLOCKS,
357
+ ):
358
+ PRESCALE_QK : tl.constexpr = False
359
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
360
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
361
+ WRITE_DQ : tl.constexpr = True
362
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
363
+ OUTPUT_MAX : tl.constexpr = False
364
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
365
+ IS_DIVISIBLE : tl.constexpr = False
366
+ SM_SCALE : tl.constexpr = 0.08838834764831845
367
+ GQA_SHARED_HEADS : tl.constexpr = 4
368
+ HAS_FULL_BLOCKS : tl.constexpr = True
369
+ QK_HEAD_DIM : tl.constexpr = 128
370
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
371
+ V_HEAD_DIM : tl.constexpr = 128
372
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
373
+ SAFE_HEAD_DIM : tl.constexpr = True
374
+ BLOCK_M1 : tl.constexpr = 64
375
+ BLOCK_N1 : tl.constexpr = 128
376
+ BLOCK_M2 : tl.constexpr = 128
377
+ BLOCK_N2 : tl.constexpr = 64
378
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
379
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
380
+ INDEX_DTYPE : tl.constexpr = tl.int32
381
+
382
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
383
+ RCP_LN2: tl.constexpr = 1.44269504
384
+ Q_LEN = ks0
385
+ KV_LEN = ks1
386
+
387
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
388
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
389
+
390
+ kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
391
+ vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
392
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
393
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
394
+
395
+ hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
396
+
397
+ for start_n in range(0, hi):
398
+ dq = bwd_dq_block_mn(
399
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
400
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
401
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
402
+ stride_kn, stride_kd, stride_vn, stride_vd,
403
+ kv_indices, sparse_kv_num_blocks,
404
+ MATMUL_PRECISION, RCP_LN2,
405
+ IS_FULL_BLOCKS,
406
+ )
407
+
408
+ # Increment pointers.
409
+ offset = get_offset_for_next_block(
410
+ start_n, kv_indices, sparse_kv_num_blocks,
411
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
412
+ )
413
+
414
+ kT_ptrs += offset * stride_kn
415
+ vT_ptrs += offset * stride_vn
416
+
417
+ offs_n2 += offset
418
+
419
+ return dq
420
+
421
+
422
+ @triton.jit
423
+ def bwd_dq_block_mn(
424
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
425
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
426
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
427
+ stride_kn, stride_kd, stride_vn, stride_vd,
428
+ kv_indices, sparse_kv_num_blocks,
429
+ MATMUL_PRECISION, RCP_LN2,
430
+ IS_FULL_BLOCKS,
431
+ ):
432
+ PRESCALE_QK : tl.constexpr = False
433
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
434
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
435
+ WRITE_DQ : tl.constexpr = True
436
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
437
+ OUTPUT_MAX : tl.constexpr = False
438
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
439
+ IS_DIVISIBLE : tl.constexpr = False
440
+ SM_SCALE : tl.constexpr = 0.08838834764831845
441
+ GQA_SHARED_HEADS : tl.constexpr = 4
442
+ HAS_FULL_BLOCKS : tl.constexpr = True
443
+ QK_HEAD_DIM : tl.constexpr = 128
444
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
445
+ V_HEAD_DIM : tl.constexpr = 128
446
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
447
+ SAFE_HEAD_DIM : tl.constexpr = True
448
+ BLOCK_M1 : tl.constexpr = 64
449
+ BLOCK_N1 : tl.constexpr = 128
450
+ BLOCK_M2 : tl.constexpr = 128
451
+ BLOCK_N2 : tl.constexpr = 64
452
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
453
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
454
+ INDEX_DTYPE : tl.constexpr = tl.int32
455
+
456
+
457
+ # NB reversed order to since K is transposed
458
+ kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
459
+ qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
460
+ if not PRESCALE_QK:
461
+ qk *= SM_SCALE
462
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
463
+ pre_mod_scores = qk
464
+ n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
465
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim
466
+ # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary
467
+ m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
468
+
469
+ tmp0 = (qk)
470
+ post_mod_scores = tmp0
471
+
472
+
473
+
474
+
475
+ if not IS_DIVISIBLE:
476
+ post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
477
+
478
+ if not IS_FULL_BLOCKS:
479
+ tmp1 = (m)
480
+ tmp2 = tl.full([1], 0, tl.int32)
481
+ tmp3 = tmp1 < tmp2
482
+ tmp4 = (n)
483
+ tmp5 = tmp4 <= tmp1
484
+ tmp6 = tmp3 & tmp5
485
+ tmp7 = tmp1 >= tmp2
486
+ tmp8 = tmp4 < tmp2
487
+ tmp9 = tmp7 & tmp8
488
+ tmp10 = tmp8 == 0
489
+ tmp11 = tmp7 & tmp10
490
+ tmp12 = tmp1 - tmp2
491
+ tmp13 = tl.full([1], 16, tl.int32)
492
+ tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
493
+ tmp15 = tmp4 - tmp2
494
+ tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
495
+ tmp17 = tmp14 == tmp16
496
+ tmp18 = tmp11 & tmp17
497
+ tmp19 = tmp9 | tmp18
498
+ tmp20 = tmp6 | tmp19
499
+ mask_mod_output = tmp20
500
+
501
+
502
+ # apply mask for partial masked block
503
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
504
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
505
+ if not PRESCALE_QK:
506
+ post_mod_scores *= RCP_LN2
507
+ p = tl.math.exp2(post_mod_scores - lse)
508
+ # Compute dP and dS.
509
+ # NB reversed order to since V is transposed
510
+ vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
511
+
512
+ dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
513
+ ds = p * (dp - Di[:, None])
514
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
515
+ tmp21 = (ds)
516
+ grad_scores = tmp21
517
+
518
+
519
+ if not IS_DIVISIBLE:
520
+ grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
521
+
522
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
523
+ if WRITE_DQ:
524
+ scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
525
+
526
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
527
+ ds = grad_scores
528
+
529
+ if not IS_FULL_BLOCKS:
530
+ # (grads) apply mask for partially unmasked block
531
+ ds = tl.where(mask_mod_output, ds, 0.0)
532
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
533
+ ds = ds.to(MATMUL_PRECISION)
534
+ # Compute dQ.
535
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
536
+
537
+ return dq
538
+
539
+
540
+ @triton.jit
541
+ def bwd_dkdv_inner(
542
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
543
+ Q, DO, DELTA, LSE, # pointers
544
+ dk, dv, k, v,
545
+ off_z, off_hq, offs_n1, offs_m1,
546
+ stride_qm, stride_qd, stride_dom, stride_dod,
547
+ q_indices, sparse_q_num_blocks,
548
+ MATMUL_PRECISION,
549
+ IS_FULL_BLOCKS,
550
+ ):
551
+ PRESCALE_QK : tl.constexpr = False
552
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
553
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
554
+ WRITE_DQ : tl.constexpr = True
555
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
556
+ OUTPUT_MAX : tl.constexpr = False
557
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
558
+ IS_DIVISIBLE : tl.constexpr = False
559
+ SM_SCALE : tl.constexpr = 0.08838834764831845
560
+ GQA_SHARED_HEADS : tl.constexpr = 4
561
+ HAS_FULL_BLOCKS : tl.constexpr = True
562
+ QK_HEAD_DIM : tl.constexpr = 128
563
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
564
+ V_HEAD_DIM : tl.constexpr = 128
565
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
566
+ SAFE_HEAD_DIM : tl.constexpr = True
567
+ BLOCK_M1 : tl.constexpr = 64
568
+ BLOCK_N1 : tl.constexpr = 128
569
+ BLOCK_M2 : tl.constexpr = 128
570
+ BLOCK_N2 : tl.constexpr = 64
571
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
572
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
573
+ INDEX_DTYPE : tl.constexpr = tl.int32
574
+
575
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
576
+ RCP_LN2: tl.constexpr = 1.44269504
577
+ Q_LEN = ks0
578
+ KV_LEN = ks1
579
+
580
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
581
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
582
+
583
+ qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
584
+ do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
585
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
586
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
587
+
588
+ # The minimum is needed to handle the case where we run with a super large
589
+ # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
590
+ hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
591
+
592
+ for start_m in range(0, hi):
593
+ dk, dv = bwd_dkdv_block_mn(
594
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
595
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
596
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
597
+ stride_qm, stride_qd, stride_dom, stride_dod,
598
+ q_indices, sparse_q_num_blocks,
599
+ MATMUL_PRECISION, RCP_LN2,
600
+ IS_FULL_BLOCKS,
601
+ )
602
+ # Increment pointers.
603
+ offset = get_offset_for_next_block(
604
+ start_m, q_indices, sparse_q_num_blocks,
605
+ SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
606
+ )
607
+
608
+ qT_ptrs += offset * stride_qm
609
+ do_ptrs += offset * stride_dom
610
+ offs_m1 += offset
611
+
612
+ return dk, dv
613
+
614
+
615
+ @triton.jit
616
+ def bwd_dkdv_block_mn(
617
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
618
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
619
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
620
+ stride_qm, stride_qd, stride_dom, stride_dod,
621
+ q_indices, sparse_q_num_blocks,
622
+ MATMUL_PRECISION, RCP_LN2,
623
+ IS_FULL_BLOCKS,
624
+ ):
625
+ PRESCALE_QK : tl.constexpr = False
626
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
627
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
628
+ WRITE_DQ : tl.constexpr = True
629
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
630
+ OUTPUT_MAX : tl.constexpr = False
631
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
632
+ IS_DIVISIBLE : tl.constexpr = False
633
+ SM_SCALE : tl.constexpr = 0.08838834764831845
634
+ GQA_SHARED_HEADS : tl.constexpr = 4
635
+ HAS_FULL_BLOCKS : tl.constexpr = True
636
+ QK_HEAD_DIM : tl.constexpr = 128
637
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
638
+ V_HEAD_DIM : tl.constexpr = 128
639
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
640
+ SAFE_HEAD_DIM : tl.constexpr = True
641
+ BLOCK_M1 : tl.constexpr = 64
642
+ BLOCK_N1 : tl.constexpr = 128
643
+ BLOCK_M2 : tl.constexpr = 128
644
+ BLOCK_N2 : tl.constexpr = 64
645
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
646
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
647
+ INDEX_DTYPE : tl.constexpr = tl.int32
648
+
649
+
650
+ # NB reversed order since Q is transposed
651
+ qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
652
+ # Load LSE before computing qk to reduce pipeline stall.
653
+ if IS_DIVISIBLE:
654
+ lse = tl.load(LSE + offs_m1)
655
+ else:
656
+ lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
657
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
658
+ qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
659
+ if not PRESCALE_QK:
660
+ qkT *= SM_SCALE
661
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
662
+ m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
663
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim
664
+ # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary
665
+ n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
666
+
667
+ pre_mod_scores = qkT
668
+ tmp22 = (qkT)
669
+ post_mod_scores = tmp22
670
+
671
+
672
+
673
+ if not IS_DIVISIBLE:
674
+ post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
675
+
676
+ if not IS_FULL_BLOCKS:
677
+ tmp23 = (m)
678
+ tmp24 = tl.full([1], 0, tl.int32)
679
+ tmp25 = tmp23 < tmp24
680
+ tmp26 = (n)
681
+ tmp27 = tmp26 <= tmp23
682
+ tmp28 = tmp25 & tmp27
683
+ tmp29 = tmp23 >= tmp24
684
+ tmp30 = tmp26 < tmp24
685
+ tmp31 = tmp29 & tmp30
686
+ tmp32 = tmp30 == 0
687
+ tmp33 = tmp29 & tmp32
688
+ tmp34 = tmp23 - tmp24
689
+ tmp35 = tl.full([1], 16, tl.int32)
690
+ tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35)
691
+ tmp37 = tmp26 - tmp24
692
+ tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35)
693
+ tmp39 = tmp36 == tmp38
694
+ tmp40 = tmp33 & tmp39
695
+ tmp41 = tmp31 | tmp40
696
+ tmp42 = tmp28 | tmp41
697
+ mask_mod_output = tmp42
698
+
699
+ # (grads) apply mask for fully masked block
700
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
701
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
702
+ if not PRESCALE_QK:
703
+ post_mod_scores *= RCP_LN2
704
+ pT = tl.math.exp2(post_mod_scores - lse[None, :])
705
+ do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
706
+ # Compute dV.
707
+ ppT = pT
708
+ dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
709
+ if IS_DIVISIBLE:
710
+ Di = tl.load(DELTA + offs_m1)
711
+ else:
712
+ Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
713
+ # Compute dP and dS.
714
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
715
+ dsT = pT * (dpT - Di[None, :])
716
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
717
+ tmp43 = (dsT)
718
+ grad_scores = tmp43
719
+
720
+
721
+
722
+ if not IS_DIVISIBLE:
723
+ grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
724
+
725
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
726
+ if not WRITE_DQ:
727
+ idx_b = off_z
728
+ idx_h = off_hq
729
+ idx_m = m
730
+ idx_n = n
731
+ scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
732
+
733
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
734
+ dsT = grad_scores
735
+ if not IS_FULL_BLOCKS:
736
+ # (grads) apply mask for partially unmasked block
737
+ dsT = tl.where(mask_mod_output, dsT, 0.0)
738
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
739
+ dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
740
+
741
+ return dk, dv
742
+
743
+ # Utility triton funcs
744
+ @triton.jit
745
+ def get_offset_for_next_block(
746
+ loop_iter, col_indices, total_blocks,
747
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
748
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
749
+ ):
750
+ if BLOCKS_ARE_CONTIGUOUS:
751
+ return BLOCK
752
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
753
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
754
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
755
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
756
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
757
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
758
+ return offset
759
+
760
+ @triton.jit
761
+ def get_bounded_indices(indices, max_len=None):
762
+ return indices % max_len if max_len is not None else indices
763
+
764
+ @triton.jit
765
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
766
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
767
+ return tl.load(block_ptr)
768
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
769
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
770
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
771
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
772
+ else:
773
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
774
+
775
+ @triton.jit
776
+ def load_checked_2d(
777
+ ptr,
778
+ offs_m,
779
+ offs_n,
780
+ stride_m,
781
+ stride_n,
782
+ IS_DIVISIBLE_M: tl.constexpr,
783
+ IS_DIVISIBLE_N: tl.constexpr,
784
+ M_LEN: tl.constexpr,
785
+ N_LEN: tl.constexpr,
786
+ ):
787
+ # Calculate final pointer if strides are provided
788
+ if stride_m is not None and stride_n is not None:
789
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
790
+
791
+ # Handle all masking cases
792
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
793
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
794
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
795
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
796
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
797
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
798
+ else: # Both divisible
799
+ return tl.load(ptr)
progress/SpecForge/cache/compiled_kernels/3z/c3zilfzjywngbdehwphwkhzpt6qcv6jecvzdajl2d5hb73xe6yzw.py ADDED
@@ -0,0 +1,582 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=2,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_flex_decoding(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = False
27
+ GQA_SHARED_HEADS : tl.constexpr = 4
28
+ HAS_FULL_BLOCKS : tl.constexpr = True
29
+ SM_SCALE : tl.constexpr = 0.08838834764831845
30
+ SPLIT_KV : tl.constexpr = 32
31
+ QK_HEAD_DIM : tl.constexpr = 128
32
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
33
+ V_HEAD_DIM : tl.constexpr = 128
34
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
35
+ SAFE_HEAD_DIM : tl.constexpr = True
36
+ BLOCK_M : tl.constexpr = 512
37
+ SAFE_M_BOUNDARY : tl.constexpr = False
38
+ SAFE_N_BOUNDARY : tl.constexpr = True
39
+ BLOCK_N : tl.constexpr = 64
40
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
41
+ USE_TMA : tl.constexpr = False
42
+ INDEX_DTYPE : tl.constexpr = tl.int32
43
+ Q = arg_Q
44
+ K = arg_K
45
+ V = arg_V
46
+ M = arg_M
47
+ L = arg_L
48
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
49
+ KV_IDX = arg_KV_IDX
50
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
51
+ FULL_KV_IDX = arg_FULL_KV_IDX
52
+
53
+ # Sub notation for this kernel:
54
+ # Q: Query, K: Key, V: Value
55
+ # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split
56
+ # M: Number of queries, N: Number of keys/values
57
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
58
+ # V_HEAD_DIM: The dimension of the value embeddings
59
+ # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block
60
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits
61
+ # (Modifiable) Config options:
62
+ # SPLIT_KV: number of blocks K & V are split into
63
+ # TILE_KV: length of each local KV split
64
+ # BLOCK_M: block size that Q is padded along seqlen dim.
65
+ # BLOCK_N: block size of K & V along N dimension.
66
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
67
+ #
68
+ # change of base out of the loop
69
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
70
+ # is not masked out? If so, we can skip an extra safety check
71
+ # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query.
72
+ # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value.
73
+
74
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base.
75
+ #
76
+ # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim.
77
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
78
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
79
+ #
80
+ #
81
+ # Output: ACC output accumulated across local KV split.
82
+
83
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
84
+
85
+ # Define Q Strides
86
+ stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1
87
+ stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1
88
+ stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1
89
+ stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1
90
+ stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1
91
+
92
+
93
+ Z = 1
94
+ ZKV = 1
95
+ HKV = 8
96
+ G: tl.constexpr = GQA_SHARED_HEADS
97
+ HQ = HKV * G
98
+ Q_LEN = ks0
99
+ KV_LEN = ks1
100
+
101
+ MATMUL_PRECISION = Q.dtype.element_ty
102
+
103
+ # Make sure each split is a multiple of BLOCK_N
104
+ TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV)
105
+ TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N
106
+ TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N)
107
+
108
+ off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV
109
+ off_zkv = off_z % ZKV
110
+ off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV
111
+ off_t = tl.program_id(1).to(INDEX_DTYPE)
112
+
113
+ q_offset = off_z * stride_qz + off_hkv * stride_qh
114
+ k_offset = off_zkv * stride_kz + off_hkv * stride_kh
115
+ v_offset = off_zkv * stride_vz + off_hkv * stride_vh
116
+
117
+ K = K + k_offset
118
+ V = V + v_offset
119
+
120
+ SPARSE_Z = 1
121
+ SPARSE_HQ = 1
122
+
123
+ sparse_idx_z = off_z % SPARSE_Z
124
+ sparse_idx_h = off_hkv % SPARSE_HQ
125
+
126
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
127
+ SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE)
128
+
129
+ # initialize pointer to m and l
130
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
131
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
132
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
133
+
134
+ # initialize offsets
135
+ tl.device_assert(BLOCK_M % G == 0)
136
+ BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G
137
+ off_g = tl.arange(0, G) # [G]
138
+ offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M]
139
+ offs_hq = offs_g + off_hkv * G
140
+ off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ]
141
+ offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M]
142
+ offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED)
143
+ offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED)
144
+
145
+ # Get HZ offsets for KV_NUM_BLKS and KV_IDX
146
+ stride_block_z, stride_block_h, stride_block_row = 1, 1, 1
147
+ sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h
148
+ stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1
149
+ sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h
150
+
151
+ # Calculate KV blocks that belong this CTA.
152
+ block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block
153
+ block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N
154
+
155
+ q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :]
156
+
157
+ if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM:
158
+ q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN))
159
+ elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM:
160
+ q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM)
161
+ elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM:
162
+ q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN)
163
+ else:
164
+ q = tl.load(Q + q_offset + q_range)
165
+
166
+ q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED])
167
+
168
+
169
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
170
+ # find first kv block we are loading and the number of blocks we are loading
171
+ # Offset the kv_indices tensor by the correct batch and head
172
+ kv_indices = KV_IDX + sparse_idx_hz_offset
173
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset)
174
+ MAX_KV_IDX = 1
175
+ indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX)
176
+ off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
177
+ off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
178
+ # first kv block we're loading
179
+
180
+ # last valid block according to sparse mask
181
+ block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
182
+
183
+ offs_n = tl.arange(0, BLOCK_N) + off_n
184
+
185
+ desc_k = None
186
+ desc_v = None
187
+
188
+ acc, l_i, m_i = forward_inner(
189
+ arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
190
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
191
+ # accumulatd values
192
+ acc, l_i, m_i,
193
+ #offsets
194
+ off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
195
+ off_n,
196
+ #block sparse data
197
+ kv_indices, kv_num_blocks,
198
+ block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
199
+ MATMUL_PRECISION,
200
+ stride_kk, stride_kn, stride_vn, stride_vk,
201
+ IS_FULL_BLOCKS=False,
202
+ )
203
+
204
+
205
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
206
+ # We know these blocks are guaranteed to be "full", so we don't need to
207
+ # apply mask_mod to them - only score_mod
208
+ if HAS_FULL_BLOCKS:
209
+ kv_indices = FULL_KV_IDX + sparse_idx_hz_offset
210
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset)
211
+ # Assign full block in a reverse order for off_t. Prioritize the last CTA.
212
+ block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE
213
+ block_n_end = block_n_start + TILE_KV_MULTIPLE
214
+ indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX)
215
+ off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
216
+ off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
217
+
218
+ # last valid block according to sparse mask
219
+ block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
220
+
221
+ offs_n = tl.arange(0, BLOCK_N) + off_n
222
+
223
+ acc, l_i, m_i = forward_inner(
224
+ arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
225
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
226
+ # accumulatd values
227
+ acc, l_i, m_i,
228
+ #offsets
229
+ off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
230
+ off_n,
231
+ #block sparse data
232
+ kv_indices, kv_num_blocks,
233
+ block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
234
+ MATMUL_PRECISION,
235
+ stride_kk, stride_kn, stride_vn, stride_vk,
236
+ IS_FULL_BLOCKS=True,
237
+ )
238
+
239
+ m_offset = off_t * stride_mt + off_z * stride_mz
240
+ l_offset = off_t * stride_lt + off_z * stride_lz
241
+
242
+ M_block_ptr = tl.make_block_ptr(
243
+ base=M + m_offset,
244
+ shape=(G, Q_LEN), # (G, M)
245
+ strides=(stride_mh, stride_mm),
246
+ offsets=(off_hkv*G, 0),
247
+ block_shape=(G, BLOCK_M_PER_HQ),
248
+ order=(1, 0)
249
+ )
250
+ L_block_ptr = tl.make_block_ptr(
251
+ base=L + l_offset,
252
+ shape=(G, Q_LEN), # (G, M)
253
+ strides=(stride_lh, stride_lm),
254
+ offsets=(off_hkv*G, 0),
255
+ block_shape=(G, BLOCK_M_PER_HQ),
256
+ order=(1, 0)
257
+ )
258
+
259
+ # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16)
260
+ m_i = m_i.reshape(G, BLOCK_M_PER_HQ)
261
+ l_i = l_i.reshape(G, BLOCK_M_PER_HQ)
262
+ if SAFE_M_BOUNDARY:
263
+ tl.store(M_block_ptr, m_i)
264
+ tl.store(L_block_ptr, l_i)
265
+ else:
266
+ tl.store(M_block_ptr, m_i, boundary_check=(1,))
267
+ tl.store(L_block_ptr, l_i, boundary_check=(1,))
268
+
269
+ # -- store output
270
+ idx_z = off_z
271
+ idx_t = off_t
272
+ idx_hq = off_hkv*G + off_g[:, None, None]
273
+ idx_m = off_m[None, :, None]
274
+ idx_d = offs_vd[None, None, :]
275
+
276
+ mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
277
+ acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
278
+ xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0
279
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask)
280
+
281
+
282
+ # Utility triton funcs
283
+ @triton.jit
284
+ def get_offset_for_next_block(
285
+ loop_iter, col_indices, total_blocks,
286
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
287
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
288
+ ):
289
+ if BLOCKS_ARE_CONTIGUOUS:
290
+ return BLOCK
291
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
292
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
293
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
294
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
295
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
296
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
297
+ return offset
298
+
299
+ @triton.jit
300
+ def get_bounded_indices(indices, max_len=None):
301
+ return indices % max_len if max_len is not None else indices
302
+
303
+ @triton.jit
304
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
305
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
306
+ return tl.load(block_ptr)
307
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
308
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
309
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
310
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
311
+ else:
312
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
313
+
314
+ @triton.jit
315
+ def load_checked_2d(
316
+ ptr,
317
+ offs_m,
318
+ offs_n,
319
+ stride_m,
320
+ stride_n,
321
+ IS_DIVISIBLE_M: tl.constexpr,
322
+ IS_DIVISIBLE_N: tl.constexpr,
323
+ M_LEN: tl.constexpr,
324
+ N_LEN: tl.constexpr,
325
+ ):
326
+ # Calculate final pointer if strides are provided
327
+ if stride_m is not None and stride_n is not None:
328
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
329
+
330
+ # Handle all masking cases
331
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
332
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
333
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
334
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
335
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
336
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
337
+ else: # Both divisible
338
+ return tl.load(ptr)
339
+
340
+
341
+ # Common Imports
342
@triton.jit
def forward_block_mn(
    arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
    q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
    # accumulated values
    acc, l_i, m_i,
    # Offsets
    off_z, off_h, offs_m, offs_n,
    # Offsets needed for TMA loads
    kv_start,
    kv_offset,
    MATMUL_PRECISION, RCP_LN2,
    # Strides for K and V
    stride_kk, stride_kn, stride_vn, stride_vk,
    IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,

):
    """Process one [BLOCK_M, BLOCK_N] attention tile and fold it into the
    running online-softmax state.

    Loads a K tile, computes scaled qk scores, applies the generated mask
    predicate (unless IS_FULL_BLOCKS), then performs the streaming-softmax
    update: rescale (acc, l_i) by exp2(m_i - m_new) and accumulate p @ V.

    Returns the updated (acc, l_i, m_i) triple.

    NOTE(review): this is machine-generated (inductor FlexAttention template);
    several parameters (arg_Q..arg_FULL_KV_IDX, desc_k, desc_v, off_z, off_h,
    out_ptr0, ks0, ks1) are never read in this body — presumably kept for a
    uniform template signature; confirm against the generator before pruning.
    """
    # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    # NOTE(review): 0.0883883... ≈ 1/sqrt(128), matching QK_HEAD_DIM=128 below.
    SM_SCALE : tl.constexpr = 0.08838834764831845
    SPLIT_KV : tl.constexpr = 32
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M : tl.constexpr = 512
    SAFE_M_BOUNDARY : tl.constexpr = False
    SAFE_N_BOUNDARY : tl.constexpr = True
    BLOCK_N : tl.constexpr = 64
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    USE_TMA : tl.constexpr = False
    INDEX_DTYPE : tl.constexpr = tl.int32


    # -- load k --
    # NB reversed order to since K is transposed
    kv_base_offset = kv_start + kv_offset

    # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
    k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)

    k = tl.trans(k)
    # -- compute qk ---
    qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
    if not PRESCALE_QK:
        qk *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
    # which is larger than the actual number of elements. To avoid access memory out of bound,
    # we need to mask out the elements that are out of Q_LEN & KV_LEN.
    # get_bounded_indices is defined elsewhere in this file; with a None bound
    # it is presumably a passthrough — confirm against its definition.
    m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
    n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)

    # Identity score_mod: scores pass through unchanged.
    tmp0 = (qk)
    post_mod_scores = tmp0


    if CHECK_BLOCK_BOUNDARY:
        # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
        post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        # Generated mask_mod predicate over (query index m, kv index n).
        # NOTE(review): the arithmetic compares sign-corrected floor divisions
        # of m and n by 16 (tmp13), i.e. 16-wide index-block membership, OR'd
        # with ordering tests on m and n — the exact user mask it encodes is
        # not recoverable from here; treat as opaque generated code.
        tmp1 = (m)
        tmp2 = tl.full([1], 0, tl.int32)
        tmp3 = tmp1 < tmp2
        tmp4 = (n)
        tmp5 = tmp4 <= tmp1
        tmp6 = tmp3 & tmp5
        tmp7 = tmp1 >= tmp2
        tmp8 = tmp4 < tmp2
        tmp9 = tmp7 & tmp8
        tmp10 = tmp8 == 0
        tmp11 = tmp7 & tmp10
        tmp12 = tmp1 - tmp2
        tmp13 = tl.full([1], 16, tl.int32)
        # Floor division emulation: adjusts truncating // toward -inf when
        # operand signs differ and there is a remainder.
        tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
        tmp15 = tmp4 - tmp2
        tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
        tmp17 = tmp14 == tmp16
        tmp18 = tmp11 & tmp17
        tmp19 = tmp9 | tmp18
        tmp20 = tmp6 | tmp19
        mask_mod_output = tmp20


        if CHECK_BLOCK_BOUNDARY:
            # Out-of-range kv positions must never pass the mask.
            mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
        # apply mask for partially unmasked blocks
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))

    if not PRESCALE_QK:
        # Convert to base-2 domain (RCP_LN2 = 1/ln 2) so exp2 below computes exp.
        post_mod_scores *= RCP_LN2
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # -- compute scaling constant ---
    # Online softmax: new running row-max across this tile.
    m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
    if not ROWS_GUARANTEED_SAFE:
        # Rows where every score is -inf would make exp2(-inf - -inf) NaN;
        # substitute 0 for the max on such rows.
        masked_out_rows = (m_ij == float("-inf"))
        m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
    else:
        m_ij_masked = m_ij

    # alpha rescales previous (acc, l_i) to the new max; p are tile weights.
    alpha = tl.math.exp2(m_i - m_ij_masked)
    p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])

    # NB: l_i update is pulled up here since it's a bit faster
    # NB: For headdim=256, it's faster to move it back down to after m_i =
    # m_ij
    l_i = l_i * alpha + tl.sum(p, 1)
    # # -- scale and update acc --
    acc = acc * alpha[:, None]
    # Calculate offsets for V loading - reuse kv_base_offset from K loading
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
    v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
    acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)

    # -- update m_i
    m_i = m_ij

    return acc, l_i, m_i
474
+
475
@triton.jit
def forward_inner(
    arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
    q, K, V,
    desc_k, desc_v, Q_LEN, KV_LEN,
    # accumulated values
    acc, l_i, m_i,
    # Offsets used as inputs to score_mod & mask_mod
    # of size [BLOCK_M, BLOCK_N] or scalar.
    off_z, off_h, offs_m, offs_n,
    # Offsets needed for TMA loads
    kv_start,
    # blocksparse data
    kv_indices, kv_num_blocks,
    # start kv and end kv block
    block_n_start, block_n_end,
    MATMUL_PRECISION,
    # Strides for K and V
    stride_kk, stride_kn, stride_vn, stride_vk,
    IS_FULL_BLOCKS,
):
    """Iterate over the sparse KV blocks [block_n_start, block_n_end) and
    fold each [BLOCK_M, BLOCK_N] tile into the online-softmax state via
    forward_block_mn.

    Between tiles, offs_n / kv_offset are advanced by the block-sparse
    offset from get_offset_for_next_block (defined elsewhere in this file).

    Returns the updated (acc, l_i, m_i) triple.
    """
    # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    SM_SCALE : tl.constexpr = 0.08838834764831845
    SPLIT_KV : tl.constexpr = 32
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M : tl.constexpr = 512
    SAFE_M_BOUNDARY : tl.constexpr = False
    SAFE_N_BOUNDARY : tl.constexpr = True
    BLOCK_N : tl.constexpr = 64
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    USE_TMA : tl.constexpr = False
    INDEX_DTYPE : tl.constexpr = tl.int32


    # How many BLOCK_N tiles fit in one sparse KV block (128 // 64 = 2).
    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
    # 1/ln 2: converts natural-log-domain scores to base-2 for exp2.
    RCP_LN2: tl.constexpr = 1.44269504

    if PRESCALE_QK:
        # Fold softmax scale + base-2 conversion into q once, up front.
        q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)

    kv_offset = 0

    # loop over k, v and update accumulator until block_n_end
    for start_n in range(block_n_start, block_n_end):
        # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
        if IS_DIVISIBLE:
            acc, l_i, m_i = forward_block_mn(
                arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
                q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
                # accumulated values
                acc, l_i, m_i,
                # Offsets
                off_z, off_h, offs_m, offs_n,
                # Offsets needed for TMA loads
                kv_start,
                kv_offset,
                MATMUL_PRECISION, RCP_LN2,
                # Strides for K and V
                stride_kk, stride_kn, stride_vn, stride_vk,
                IS_FULL_BLOCKS,
            )
        else:
            # Benchmark shows even we applied mod & mask to each block for non divisible seqlen,
            # it's on par or slightly faster than only applying to the last block in fwd.
            # However, we choose different strategy for bwd, where we only apply mod & mask
            # to the last block because it's faster a lot.
            acc, l_i, m_i = forward_block_mn(
                arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
                q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
                # accumulated values
                acc, l_i, m_i,
                # Offsets
                off_z, off_h, offs_m, offs_n,
                # Offsets needed for TMA loads
                kv_start,
                kv_offset,
                MATMUL_PRECISION, RCP_LN2,
                # Strides for K and V
                stride_kk, stride_kn, stride_vn, stride_vk,
                IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
            )



        # Advance to the next block-sparse KV tile; the helper is defined
        # elsewhere in this file (indexes kv_indices / kv_num_blocks).
        offset = get_offset_for_next_block(
            start_n, kv_indices, kv_num_blocks,
            SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
        )

        offs_n = offs_n + offset
        kv_offset += offset


    return acc, l_i, m_i
progress/SpecForge/cache/compiled_kernels/4a/7887d45b1aa6124e232769adbe995f9cc2af0dd187cb9928540172d82c7b8631.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "BZAXIZYYJGUVREZ5ANMEKVK5UU77TPVNED7QAB22EKNJIKFVURYA"}