Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes. See raw diff
- progress/SpecForge/.devcontainer/Dockerfile +32 -0
- progress/SpecForge/.devcontainer/devcontainer.json +30 -0
- progress/SpecForge/.github/CODEOWNERS +11 -0
- progress/SpecForge/.github/pull_request_template.md +30 -0
- progress/SpecForge/assets/logo.svg +0 -0
- progress/SpecForge/benchmarks/README.md +67 -0
- progress/SpecForge/benchmarks/__init__.py +3 -0
- progress/SpecForge/benchmarks/bench_eagle3.py +268 -0
- progress/SpecForge/benchmarks/benchmarker/__init__.py +29 -0
- progress/SpecForge/benchmarks/benchmarker/aime.py +133 -0
- progress/SpecForge/benchmarks/benchmarker/base.py +218 -0
- progress/SpecForge/benchmarks/benchmarker/ceval.py +267 -0
- progress/SpecForge/benchmarks/benchmarker/financeqa.py +59 -0
- progress/SpecForge/benchmarks/benchmarker/gpqa.py +85 -0
- progress/SpecForge/benchmarks/benchmarker/gsm8k.py +99 -0
- progress/SpecForge/benchmarks/benchmarker/humaneval.py +188 -0
- progress/SpecForge/benchmarks/benchmarker/livecodebench.py +46 -0
- progress/SpecForge/benchmarks/benchmarker/math500.py +122 -0
- progress/SpecForge/benchmarks/benchmarker/mmlu.py +82 -0
- progress/SpecForge/benchmarks/benchmarker/mmstar.py +185 -0
- progress/SpecForge/benchmarks/benchmarker/mtbench.py +59 -0
- progress/SpecForge/benchmarks/benchmarker/registry.py +31 -0
- progress/SpecForge/benchmarks/benchmarker/simpleqa.py +42 -0
- progress/SpecForge/benchmarks/benchmarker/utils.py +273 -0
- progress/SpecForge/cache/compiled_kernels/26/c26l7dxpqbfol7d62sqakxdv4rgyh27yhm4hrctevbkw5t6kekia.py +799 -0
- progress/SpecForge/cache/compiled_kernels/2d/c2d4e47kqxxnp6455gvkteqq3r336462zkbitosyeko6znxktn2b.py +879 -0
- progress/SpecForge/cache/compiled_kernels/2g/c2gswut4q57fp2ueybipg5qfqiy5coitofujwdnvqdwhr7nbvnyq.py +534 -0
- progress/SpecForge/cache/compiled_kernels/2j/4b74fa21eaaf86b6290185f6fe50aec9b905d858a087238ceddb52477f3f6acb.best_config +1 -0
- progress/SpecForge/cache/compiled_kernels/2j/c2j3mtk3thi6sn2hxiuhuigjw43spiu74mxdervpgpfrtos7u2qh.py +28 -0
- progress/SpecForge/cache/compiled_kernels/2n/c2ngvuchx6agpdr6v7awl3qgblaehfzaauoxn6camwvtk7syoxsk.py +715 -0
- progress/SpecForge/cache/compiled_kernels/2n/c2nooi7ekpz4qvmvghggbegd5cyfspb27jmq2snbi26zbrpoibnx.py +48 -0
- progress/SpecForge/cache/compiled_kernels/2n/d17ff4e7bb44e5ae89a267ef332bb7c074804ce0942fc0694c3ef15b05f7854a.best_config +1 -0
- progress/SpecForge/cache/compiled_kernels/2o/c2oashzxz74kzyuwo67tuhk32cike37ysabriftachdv7lf2qxgs.py +799 -0
- progress/SpecForge/cache/compiled_kernels/2v/c2vob47d7sxpitzmofyr55f5hvxsitxjhpyv5hdiqcdjgbwmxk76.py +799 -0
- progress/SpecForge/cache/compiled_kernels/2y/c2yhndikcsebqfmbw7l44gmcdoyw7ogaqt7quyeygz3mp5w6u6ke.py +715 -0
- progress/SpecForge/cache/compiled_kernels/2z/c2zdv5arszdl6ednyphqfnib6jwgzomr6zt6536b7gq75kp67uvh.py +1046 -0
- progress/SpecForge/cache/compiled_kernels/2z/c2zqq6qyjomc7iflknbqr7yjdhjux47hzv4nnsi5qfbeqglaip2h.py +707 -0
- progress/SpecForge/cache/compiled_kernels/32/8d96bbe05a966b7e7756831f09a79e31bf46fad0952af86f36d75557fc1735e8.best_config +1 -0
- progress/SpecForge/cache/compiled_kernels/32/c32pbcuz72bjfnkzvckfbbzlzuupc5yxl7t47b3qf74mmk5g2d2z.py +27 -0
- progress/SpecForge/cache/compiled_kernels/3b/a0a6b043ab548fdf71e72bbdf5daab7f72e9ed11a9ad9f8824a6263bb6bc5081.best_config +1 -0
- progress/SpecForge/cache/compiled_kernels/3b/c3bqw7dk7k6dcdrp3ycrthotye7y6zb26752jl4lwmfgaybpvr6y.py +27 -0
- progress/SpecForge/cache/compiled_kernels/3f/3f6057605b157d44fd56f748226a63975b79198f94871188e73e46cd6c7f8792.best_config +1 -0
- progress/SpecForge/cache/compiled_kernels/3f/c3fttv7enp2yvnla3r6jkk4galt2qdpxw577ghvkmmx6zqaqla74.py +54 -0
- progress/SpecForge/cache/compiled_kernels/3n/c3nlaqknekmjv2zuxzow4rf42v3gorxnfp6uod3dg3ic5ibp6yp3.py +715 -0
- progress/SpecForge/cache/compiled_kernels/3q/c3qbvcsx2w7qss2v3eocuadgz6t35joo33bflzqkxzzj747zcjpk.py +51 -0
- progress/SpecForge/cache/compiled_kernels/3q/fc5920467dd1501963c976e2b895fc37747fdebfa098fff912209055f3a31828.best_config +1 -0
- progress/SpecForge/cache/compiled_kernels/3r/c3rkwwyedldrjz6sidtx5huqcsdgpdpu4xndmm6h4e4boo6cbg2w.py +702 -0
- progress/SpecForge/cache/compiled_kernels/3z/c3zi2pt6zmbthc6ythgt5p4ednhp6m24gpscb2pt6adf6xojetua.py +799 -0
- progress/SpecForge/cache/compiled_kernels/3z/c3zilfzjywngbdehwphwkhzpt6qcv6jecvzdajl2d5hb73xe6yzw.py +582 -0
- progress/SpecForge/cache/compiled_kernels/4a/7887d45b1aa6124e232769adbe995f9cc2af0dd187cb9928540172d82c7b8631.best_config +1 -0
progress/SpecForge/.devcontainer/Dockerfile
ADDED
@@ -0,0 +1,32 @@
+FROM lmsysorg/sglang:dev
+
+# Create non-root user with specified UID and GID
+# NOTE: Replace with your own UID and GID. This is a workaround from https://github.com/microsoft/vscode-remote-release/issues/49#issuecomment-489060908.
+ARG HOST_UID=1003
+ARG HOST_GID=1003
+RUN groupadd -g $HOST_GID devuser && \
+    useradd -m -u $HOST_UID -g $HOST_GID -s /bin/zsh devuser
+
+# Give devuser sudo access
+RUN apt-get update && apt-get install -y sudo && \
+    echo "devuser ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/devuser && \
+    rm -rf /var/lib/apt/lists/* && \
+    apt-get clean
+
+# Set up oh-my-zsh for devuser
+RUN cp -r /root/.oh-my-zsh /home/devuser/.oh-my-zsh && \
+    cp /root/.zshrc /home/devuser/.zshrc && \
+    cp /root/.vimrc /home/devuser/.vimrc && \
+    cp /root/.tmux.conf /home/devuser/.tmux.conf && \
+    sed -i 's|/root/.oh-my-zsh|/home/devuser/.oh-my-zsh|g' /home/devuser/.zshrc && \
+    chown -R devuser:devuser /home/devuser/
+
+# Set workspace directory and ownership
+WORKDIR /sgl-workspace/sglang
+RUN chown -R devuser:devuser /sgl-workspace
+
+# Switch to devuser
+USER devuser
+
+# Install rust
+RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
progress/SpecForge/.devcontainer/devcontainer.json
ADDED
@@ -0,0 +1,30 @@
+{
+    "name": "sglang",
+    "build": {
+        "dockerfile": "Dockerfile"
+    },
+    "remoteUser": "devuser",
+    "customizations": {
+        "vscode": {
+            "extensions": [
+                // Python development
+                "ms-python.python",
+                "charliermarsh.ruff",
+                // Rust development
+                "rust-lang.rust-analyzer",
+                "tamasfe.even-better-toml"
+            ]
+        }
+    },
+    "forwardPorts": [],
+    "runArgs": [
+        "--gpus",
+        "all"
+    ],
+    // The two lines below ensure that your local changes in the sglang
+    // repo are automatically synced to the sglang pip package installed
+    // in the dev docker container. You can remove / comment out these
+    // two lines if you prefer to sync code changes manually.
+    "workspaceMount": "source=${localWorkspaceFolder},target=/sgl-workspace/specforge,type=bind",
+    "workspaceFolder": "/sgl-workspace/specforge"
+}
progress/SpecForge/.github/CODEOWNERS
ADDED
@@ -0,0 +1,11 @@
+.github @FrankLeeeee
+/specforge/core @FrankLeeeee
+/specforge/data @zyksir @sleepcoo @shuaills
+/specforge/layers @FrankLeeeee @FlamingoPg @sleepcoo @shuaills
+/specforge/modeling @FlamingoPg @sleepcoo @shuaills @FrankLeeeee
+/tests @FrankLeeeee
+/assets @FrankLeeeee @zhyncs
+/examples @shuaills @sleepcoo @FlamingoPg
+/configs @FrankLeeeee @FlamingoPg
+/benchmarks @FrankLeeeee
+/scripts @shuaills @sleepcoo @FlamingoPg
progress/SpecForge/.github/pull_request_template.md
ADDED
@@ -0,0 +1,30 @@
+<!-- Thank you for your contribution! We appreciate it. The following guidelines will help improve your pull request and facilitate feedback. If anything is unclear, don't hesitate to submit your pull request and ask the maintainers for assistance. -->
+
+## Motivation
+
+<!-- Explain the purpose of this PR and the goals it aims to achieve. -->
+
+## Modifications
+
+<!-- Describe the changes made in this PR. -->
+
+## Related Issues
+
+<!-- Link to any related issues here. e.g. "Fixes #123" or "Closes #456" -->
+
+## Accuracy Test
+
+<!-- If this PR affects model-side code (e.g., kernels, model architecture), please provide accuracy test results. Ref: https://docs.sglang.ai/references/accuracy_evaluation.html -->
+
+## Benchmark & Profiling
+
+<!-- If this PR is expected to impact performance, please provide benchmark and profiling results. Ref: https://docs.sglang.ai/references/benchmark_and_profiling.html -->
+
+## Checklist
+
+- [ ] Format your code according to the [Code Formatting with Pre-Commit](https://docs.sglang.ai/references/contribution_guide.html#code-formatting-with-pre-commit).
+- [ ] Add unit tests as outlined in [Running Unit Tests](https://docs.sglang.ai/references/contribution_guide.html#running-unit-tests-adding-to-ci).
+- [ ] Update documentation / docstrings / example tutorials as needed, according to [Writing Documentation](https://docs.sglang.ai/references/contribution_guide.html#writing-documentation-running-docs-ci).
+- [ ] Provide throughput / latency benchmark results and accuracy evaluation results as needed, according to [Benchmark and Profiling](https://docs.sglang.ai/references/benchmark_and_profiling.html) and [Accuracy Results](https://docs.sglang.ai/references/accuracy_evaluation.html).
+- [ ] For reviewers: If you haven't made any contributions to this PR and are only assisting with merging the main branch, please remove yourself as a co-author when merging the PR.
+- [ ] Please feel free to join our Slack channel at https://sgl-fru7574.slack.com/archives/C09784E3EN6 to discuss your PR.
progress/SpecForge/assets/logo.svg
ADDED
progress/SpecForge/benchmarks/README.md
ADDED
@@ -0,0 +1,67 @@
+# Benchmarking for Speculative Decoding
+
+## Overview
+
+We provide a unified script to test the performance of speculative decoding with the EAGLE3 algorithm on multiple datasets. You can follow the steps below to run the benchmarks.
+
+## Run Benchmarks
+
+### Launch SGLang and Benchmarker Concurrently
+
+`bench_eagle3.py` can launch an SGLang server process and a benchmarking process concurrently. This way, you don't have to launch the SGLang server yourself; the script handles the server launch for each speculative decoding configuration. Some important arguments are:
+- `--model-path`: the path to the target model.
+- `--speculative-draft-model-path`: the path to the draft model.
+- `--port`: the port to launch the SGLang server on.
+- `--trust-remote-code`: trust the remote code.
+- `--mem-fraction-static`: the memory fraction for the static memory.
+- `--tp-size`: the tensor parallelism size.
+- `--attention-backend`: the attention backend.
+- `--config-list`: the list of speculative decoding configurations to test, in the format `<batch-size>,<num-steps>,<topk>,<num-draft-tokens>`. A config with zero steps (e.g. `1,0,0,0`) disables speculative decoding and serves as the baseline.
+- `--benchmark-list`: the list of benchmarks to test, in the format `<benchmark-name>:<num-prompts>:<subset>`.
+
+```shell
+python3 bench_eagle3.py \
+    --model-path meta-llama/Llama-3.1-8B-Instruct \
+    --speculative-draft-model-path lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B \
+    --port 30000 \
+    --trust-remote-code \
+    --mem-fraction-static 0.8 \
+    --tp-size 1 \
+    --attention-backend fa3 \
+    --config-list 1,0,0,0 1,3,1,4 \
+    --benchmark-list mtbench gsm8k:5 ceval:5:accountant \
+    --dtype bfloat16
+```
+
+### Launch Benchmarker Independently
+
+If you want to launch the SGLang server independently, you can use the following command.
+
+```shell
+# you can launch a server
+python3 -m sglang.launch_server \
+    --model meta-llama/Llama-3.1-8B-Instruct \
+    --speculative-algorithm EAGLE3 \
+    --speculative-draft-model-path lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B \
+    --speculative-num-steps 3 \
+    --speculative-eagle-topk 1 \
+    --speculative-num-draft-tokens 4 \
+    --mem-fraction-static 0.75 \
+    --cuda-graph-max-bs 1 \
+    --tp 1 \
+    --trust-remote-code \
+    --host 0.0.0.0 \
+    --port 30000 \
+    --dtype bfloat16
+```
+
+Then we can start benchmarking. Use the same host and port as the running SGLang server, and note that `--skip-launch-server` is required so the script skips launching its own server.
+
+```bash
+python bench_eagle3.py \
+    --model-path meta-llama/Llama-3.1-8B-Instruct \
+    --port 30000 \
+    --config-list 1,3,1,4 \
+    --benchmark-list mtbench:5 ceval:5:accountant gsm8k:5 humaneval:5 math500:5 mtbench:5 aime:1 \
+    --skip-launch-server
+```
progress/SpecForge/benchmarks/__init__.py
ADDED
@@ -0,0 +1,3 @@
+"""
+Benchmark scripts for speculative decoding evaluation.
+"""
progress/SpecForge/benchmarks/bench_eagle3.py
ADDED
@@ -0,0 +1,268 @@
+#!/usr/bin/env python3
+"""
+Usage:
+
+# if you want to run benchmarks directly
+# mtbench:20 means only run 20 samples in the dataset
+python bench_eagle3.py \
+    --model meta-llama/Llama-3.1-8B-Instruct \
+    --speculative-algorithm EAGLE3 \
+    --speculative-draft-model-path lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B \
+    --port 30000 \
+    --config-list 1,0,0,0 1,3,1,4 \
+    --benchmark-list mtbench:20 \
+    --dtype bfloat16
+
+
+or if you want to run sglang alone.
+
+# launch sglang
+python3 -m sglang.launch_server \
+    --model meta-llama/Llama-3.1-8B-Instruct \
+    --speculative-algorithm EAGLE3 \
+    --speculative-draft-model-path lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B \
+    --speculative-num-steps 3 \
+    --speculative-eagle-topk 1 \
+    --speculative-num-draft-tokens 4 \
+    --mem-fraction-static 0.75 \
+    --cuda-graph-max-bs 1 \
+    --tp 1 \
+    --trust-remote-code \
+    --host 0.0.0.0 \
+    --port 30000 \
+    --dtype bfloat16
+
+# then run benchmarks
+python bench_eagle3.py \
+    --model-path meta-llama/Llama-3.1-8B-Instruct \
+    --port 30000 \
+    --config-list 1,0,0,0 \
+    --benchmark-list mtbench:80 \
+    --dtype bfloat16 \
+    --skip-launch-server
+"""
+import argparse
+import json
+import os
+import time
+from dataclasses import asdict
+from typing import List
+
+import requests
+from benchmarker import BENCHMARKS
+from sglang.srt.server_args import ServerArgs
+from sglang.test.test_utils import kill_process_tree, popen_launch_server
+from sglang.utils import wait_for_server
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    sglang_group = parser.add_argument_group("sglang")
+    ServerArgs.add_cli_args(sglang_group)
+
+    # make the following args a group
+    benchmark_group = parser.add_argument_group("benchmark")
+    benchmark_group.add_argument(
+        "--skip-launch-server", action="store_true", default=False
+    )
+    benchmark_group.add_argument("--timeout-for-server-launch", type=int, default=600)
+    benchmark_group.add_argument("--num-prompts", type=int, default=80)
+    benchmark_group.add_argument("--output-dir", type=str, default="./results")
+    benchmark_group.add_argument(
+        "--config-list", type=str, nargs="+", default=["1,0,0,0", "1,3,1,4"]
+    )
+    benchmark_group.add_argument(
+        "--name",
+        type=str,
+        default=None,
+        help="name of this benchmark run, if provided, will be added to the output file name",
+    )
+    benchmark_group.add_argument(
+        "--benchmark-list",
+        type=str,
+        nargs="+",
+        default=[
+            "mtbench:80",
+            "gsm8k:200",
+            "humaneval:200",
+            "math500:200",
+            "ceval:200",
+        ],
+        help=f"The list of benchmarks to run. The format is <benchmark-name>:<num-prompts>:<subset>,<subset>. We support the following benchmarks: {', '.join(BENCHMARKS.benchmarks.keys())}",
+    )
+    benchmark_group.add_argument(
+        "--enable-multi-turn-conversation",
+        action="store_true",
+        default=False,
+    )
+    return parser.parse_args()
+
+
+def launch_sglang_server(
+    server_args: ServerArgs,
+    base_url: str,
+    batch_size: int,
+    steps: int,
+    topk: int,
+    num_draft_tokens: int,
+    timeout: int,
+):
+    """
+    This function launches the SGLang server with the given server arguments.
+    """
+    sglang_args: List[str] = []
+    if steps > 0:
+        sglang_args.extend(
+            [
+                "--speculative-algorithm",
+                "EAGLE3",
+                "--speculative-num-steps",
+                str(steps),
+                "--speculative-eagle-topk",
+                str(topk),
+                "--speculative-num-draft-tokens",
+                str(num_draft_tokens),
+                "--speculative-draft-model-path",
+                server_args.speculative_draft_model_path,
+            ]
+        )
+
+    sglang_args.extend(
+        [
+            "--cuda-graph-max-bs",
+            str(batch_size),
+            "--mem-fraction-static",
+            str(server_args.mem_fraction_static),
+            "--tp-size",
+            str(server_args.tp_size),
+            "--max-running-requests",
+            str(batch_size),
+        ]
+    )
+
+    if server_args.trust_remote_code:
+        sglang_args.extend(["--trust-remote-code"])
+
+    if server_args.disable_radix_cache:
+        sglang_args.extend(["--disable-radix-cache"])
+
+    if server_args.ep_size:
+        sglang_args.extend(["--ep-size", str(server_args.ep_size)])
+
+    if server_args.attention_backend:
+        sglang_args.extend(["--attention-backend", server_args.attention_backend])
+
+    if server_args.quantization:
+        sglang_args.extend(["--quantization", server_args.quantization])
+
+    if server_args.dtype:
+        sglang_args.extend(["--dtype", server_args.dtype])
+
+    process = popen_launch_server(
+        server_args.model_path,
+        base_url,
+        timeout=timeout,
+        other_args=sglang_args,
+        env={
+            "SGLANG_RECORD_STEP_TIME": "1",
+            "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN": "1",
+            **os.environ,
+        },
+    )
+    return process
+
+
+def send_flush_cache_request(base_url: str):
+    requests.post(base_url + "/flush_cache")
+
+
+def main():
+    args = parse_args()
+    server_args: ServerArgs = ServerArgs.from_cli_args(args)
+    configs = [tuple(map(int, config.split(","))) for config in args.config_list]
+
+    # split each arg into (bench_name, num_prompts, subset)
+    benchmark_list = []
+    for item in args.benchmark_list:
+        splits = item.split(":")
+        if len(splits) == 1:
+            bench_name = splits[0]
+            num_prompts = None
+            subset = None
+        elif len(splits) == 2:
+            bench_name, num_prompts = splits
+            subset = None
+        elif len(splits) == 3:
+            bench_name, num_prompts, subset = splits
+            subset = subset.split(",")
+        else:
+            raise ValueError(f"Invalid benchmark list format: {item}")
+        benchmark_list.append((bench_name, num_prompts, subset))
+    assert len(benchmark_list) != 0, "the benchmark list is empty"
+
+    base_url = f"http://localhost:{args.port}"
+
+    results = {}
+    results["model"] = server_args.speculative_draft_model_path
+
+    def run_benchmarks(batch_size: int, steps: int, topk: int, num_draft_tokens: int):
+        for benchmark_name, num_prompts, subset in benchmark_list:
+            print(
+                f"Running benchmark {benchmark_name} with {num_prompts} prompts, batch size {batch_size}, steps {steps}, topk {topk}, num_draft_tokens {num_draft_tokens}, subset {subset}"
+            )
+            benchmarker_cls = BENCHMARKS.get(benchmark_name)
+            num_prompts = int(num_prompts) if num_prompts is not None else None
+            if subset is None:
+                benchmarker = benchmarker_cls(num_samples=num_prompts)
+            else:
+                benchmarker = benchmarker_cls(num_samples=num_prompts, subset=subset)
+            metrics_list = benchmarker.run(
+                host=args.host, port=args.port, batch_size=batch_size
+            )
+            send_flush_cache_request(base_url)
+            if benchmark_name not in results:
+                results[benchmark_name] = []
+            results[benchmark_name].append(
+                dict(
+                    batch_size=batch_size,
+                    steps=steps,
+                    topk=topk,
+                    num_draft_tokens=num_draft_tokens,
+                    metrics=[asdict(metric) for metric in metrics_list],
+                    num_samples=num_prompts,
+                )
+            )
+
+    if args.skip_launch_server:
+        batch_size = configs[0][0] if len(configs) > 0 else 8
+        run_benchmarks(batch_size, None, None, None)
+    else:
+        # iterate over each config from args
+        for batch_size, steps, topk, num_draft_tokens in configs:
+            process = launch_sglang_server(
+                server_args,
+                base_url,
+                batch_size,
+                steps,
+                topk,
+                num_draft_tokens,
+                args.timeout_for_server_launch,
+            )
+            wait_for_server(base_url)
+            run_benchmarks(batch_size, steps, topk, num_draft_tokens)
+            kill_process_tree(process.pid)
+            process.wait()
+
+    os.makedirs(args.output_dir, exist_ok=True)
+    timestamp = time.strftime("%Y%m%d_%H%M%S")
+    result_file = os.path.join(
+        args.output_dir,
+        f"{args.name + '_' if args.name else ''}results_{timestamp}.jsonl",
+    )
+    with open(result_file, "w") as f:
+        json.dump(results, f, indent=4)
+    print(f"Results saved to {result_file}")
+
+
+if __name__ == "__main__":
+    main()
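Note on consuming the output: `main()` writes one JSON object with `json.dump`, even though the filename carries a `.jsonl` suffix. A minimal sketch of loading the results; the filename below is a hypothetical example:

```python
import json

# Hypothetical filename; the real one embeds args.name and a timestamp.
with open("results/results_20250101_000000.jsonl") as f:
    results = json.load(f)  # a single JSON object despite the .jsonl suffix

# Each benchmark key maps to one entry per speculative decoding config.
for entry in results.get("mtbench", []):
    print(entry["batch_size"], entry["steps"], entry["topk"], entry["num_draft_tokens"])
```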
progress/SpecForge/benchmarks/benchmarker/__init__.py
ADDED
@@ -0,0 +1,29 @@
+from .aime import AIMEBenchmarker
+from .ceval import CEvalBenchmarker
+from .financeqa import FinanceQABenchmarker
+from .gpqa import GPQABenchmarker
+from .gsm8k import GSM8KBenchmarker
+from .humaneval import HumanEvalBenchmarker
+from .livecodebench import LCBBenchmarker
+from .math500 import Math500Benchmarker
+from .mmlu import MMLUBenchmarker
+from .mmstar import MMStarBenchmarker
+from .mtbench import MTBenchBenchmarker
+from .registry import BENCHMARKS
+from .simpleqa import SimpleQABenchmarker
+
+__all__ = [
+    "BENCHMARKS",
+    "AIMEBenchmarker",
+    "CEvalBenchmarker",
+    "GSM8KBenchmarker",
+    "HumanEvalBenchmarker",
+    "Math500Benchmarker",
+    "MTBenchBenchmarker",
+    "MMStarBenchmarker",
+    "GPQABenchmarker",
+    "FinanceQABenchmarker",
+    "MMLUBenchmarker",
+    "LCBBenchmarker",
+    "SimpleQABenchmarker",
+]
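`registry.py` itself falls outside the 50-file view, but from its usage here and in `bench_eagle3.py` (`BENCHMARKS.register(name)` as a class decorator, `BENCHMARKS.get(name)`, and `BENCHMARKS.benchmarks.keys()`), a minimal registry with the same interface would look roughly like the sketch below. This is an assumption about its shape, not the actual file contents:

```python
# Sketch of a registry matching the observed interface; not the real registry.py.
class BenchmarkRegistry:
    def __init__(self):
        self.benchmarks = {}  # name -> Benchmarker subclass

    def register(self, name: str):
        # Used as a class decorator: @BENCHMARKS.register("aime")
        def wrapper(cls):
            self.benchmarks[name] = cls
            return cls

        return wrapper

    def get(self, name: str):
        if name not in self.benchmarks:
            raise KeyError(f"Unknown benchmark: {name}")
        return self.benchmarks[name]


BENCHMARKS = BenchmarkRegistry()
```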
progress/SpecForge/benchmarks/benchmarker/aime.py
ADDED
@@ -0,0 +1,133 @@
+"""
+AIME benchmark
+"""
+
+import re
+from typing import Any, Dict, List, Optional, Tuple
+
+from datasets import load_dataset
+
+from .base import Benchmarker
+from .registry import BENCHMARKS
+from .utils import create_simple_sgl_function
+
+
+def extract_aime_answer(output: str) -> Optional[str]:
+    """Extract final answer from AIME problem solution.
+
+    AIME answers are typically integers between 0 and 999, and are usually
+    in \boxed{} format.
+    """
+    # Try to find answer in \boxed{} format
+    boxed_pattern = r"\\boxed\{([^}]+)\}"
+    match = re.search(boxed_pattern, output)
+    if match:
+        answer = match.group(1).strip()
+        # Extract number from the boxed content
+        numbers = re.findall(r"\d+", answer)
+        if numbers:
+            return numbers[-1]  # Take the last number (usually the final answer)
+        return answer
+
+    # Try to find answer in \boxed format (without braces)
+    boxed_pattern2 = r"\\boxed\s+(\d+)"
+    match = re.search(boxed_pattern2, output)
+    if match:
+        return match.group(1).strip()
+
+    # Look for patterns like "The answer is 42" or "Answer: 123"
+    answer_patterns = [
+        r"(?:answer|Answer|ANSWER)[\s:]+(\d+)",
+        r"(?:final\s+answer|Final\s+Answer)[\s:]+(\d+)",
+        r"(?:is|equals?|=\s*)(\d+)\s*$",
+    ]
+    for pattern in answer_patterns:
+        matches = re.findall(pattern, output, re.IGNORECASE)
+        if matches:
+            return matches[-1].strip()
+
+    # Fallback: extract the last integer in the text
+    numbers = re.findall(r"\b(\d+)\b", output)
+    if numbers:
+        # Filter to reasonable AIME answer range (0-999)
+        valid_numbers = [n for n in numbers if 0 <= int(n) <= 999]
+        if valid_numbers:
+            return valid_numbers[-1]
+
+    return None
+
+
+@BENCHMARKS.register("aime")
+class AIMEBenchmarker(Benchmarker):
+    """AIME benchmark implementation."""
+
+    def __init__(self, num_samples: Optional[int] = None):
+        super().__init__(num_samples, None)
+
+    def load_data(self) -> Tuple[List[Dict[str, Any]], List[Optional[str]]]:
+        """Load and preprocess AIME dataset."""
+        dataset = load_dataset("Maxwell-Jia/AIME_2024")["train"]
+        questions = []
+        labels = []
+        for idx, q in enumerate(dataset):
+            if self.num_samples is not None and idx >= self.num_samples:
+                break
+
+            questions.append({"question": q["Problem"]})
+            # Extract answer from Answer field
+            answer = None
+            if "Answer" in q:
+                answer = str(q["Answer"]).strip()
+            elif "answer" in q:
+                answer = str(q["answer"]).strip()
+            labels.append(answer)
+        return questions, labels
+
+    def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[str]:
+        """Extract answer from model output."""
+        return extract_aime_answer(output)
+
+    def compute_accuracy(
+        self, predictions: List[Any], labels: List[Any]
+    ) -> Optional[float]:
+        """Compute accuracy for AIME by comparing numeric answers."""
+        if not labels or len(labels) == 0:
+            return None
+        if all(label is None for label in labels):
+            return None
+
+        correct = 0
+        valid_count = 0
+        for pred, label in zip(predictions, labels):
+            if label is not None:
+                valid_count += 1
+                if pred is not None:
+                    # Normalize answers for comparison
+                    pred_normalized = str(pred).strip()
+                    label_normalized = str(label).strip()
+                    # Try exact match first
+                    if pred_normalized == label_normalized:
+                        correct += 1
+                    else:
+                        # Try numeric comparison
+                        try:
+                            pred_num = int(pred_normalized)
+                            label_num = int(label_normalized)
+                            if pred_num == label_num:
+                                correct += 1
+                        except ValueError:
+                            pass
+
+        return correct / valid_count if valid_count > 0 else 0.0
+
+    def create_sgl_function(self):
+        """Create SGL function for AIME with reasoning prompt."""
+        return create_simple_sgl_function(
+            function_name="reasoning_gen",
+            answer_key="answer",
+            user_prefix="\nPlease reason step by step, and put your final answer within \\boxed{}.",
+        )
+
+    def get_max_new_tokens(self) -> int:
+        """AIME problems require more tokens."""
+        return 32768
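A quick usage sketch for the extraction fallbacks above (inputs invented; the import path assumes the repo root is on `sys.path`):

```python
from benchmarks.benchmarker.aime import extract_aime_answer  # assumed import path

print(extract_aime_answer(r"... so the count is \boxed{204}."))  # "204" via the boxed pattern
print(extract_aime_answer("Final Answer: 42"))  # "42" via the answer-phrase patterns
print(extract_aime_answer("we tried 5000, and 73 remained"))  # "73": last integer in the 0-999 range
```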
progress/SpecForge/benchmarks/benchmarker/base.py
ADDED
@@ -0,0 +1,218 @@
+"""
+Base class for benchmark implementations.
+"""
+
+import time
+from abc import ABC, abstractmethod
+from argparse import Namespace
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+from sglang import set_default_backend
+from sglang.test.test_utils import select_sglang_backend
+
+from .utils import compute_metrics
+
+
+class Benchmarker(ABC):
+    """
+    Base class for benchmark implementations.
+
+    Subclasses should implement:
+    - load_data(): Load and preprocess dataset
+    - create_sgl_function(): Create the SGL function for inference
+
+    Optional overrides:
+    - extract_answer(): Extract answer from model output (if needed)
+    - compute_accuracy(): Compute accuracy metric (if applicable)
+    - get_answer_keys(): Get list of answer keys for multi-turn conversations
+
+    Args:
+        num_samples: The number of samples to run the benchmark on. If not provided, all questions will be used.
+        subset: The subset of the dataset to run the benchmark on. If not provided, all subsets will be used.
+    """
+
+    def __init__(
+        self, num_samples: Optional[int] = None, subset: Optional[List[str]] = None
+    ):
+        self.num_samples = num_samples
+        self.subset = subset
+
+    @abstractmethod
+    def load_data(self) -> Tuple[List[Dict[str, Any]], List[Any]]:
+        """
+        Load and preprocess the dataset.
+
+        Returns:
+            Tuple of (questions, labels) where:
+            - questions: List of question dicts for SGL function
+            - labels: List of ground truth labels (can be None if not applicable)
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def create_sgl_function(self) -> Callable:
+        """
+        Create the SGL function for inference.
+
+        Returns:
+            SGL function decorated with @sgl.function
+        """
+        raise NotImplementedError
+
+    def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[Any]:
+        """
+        Extract answer from model output.
+
+        Args:
+            output: Raw model output string
+            label: Optional ground truth label for reference
+
+        Returns:
+            Extracted answer, or None if extraction fails
+        """
+        return output
+
+    def compute_accuracy(
+        self, predictions: List[Any], labels: List[Any]
+    ) -> Optional[float]:
+        """
+        Compute accuracy metric.
+
+        Args:
+            predictions: List of predicted answers
+            labels: List of ground truth labels
+
+        Returns:
+            Accuracy score (0-1), or None if not applicable
+        """
+        return None
+
+    def get_answer_keys(self) -> Optional[List[str]]:
+        """
+        Get list of answer keys for multi-turn conversations.
+
+        Returns:
+            List of answer keys (e.g., ["answer_1", "answer_2"]), or None for single-turn
+        """
+        return None
+
+    def get_max_new_tokens(self) -> int:
+        """
+        Get maximum number of new tokens to generate.
+
+        Returns:
+            Maximum tokens (default: 2048)
+        """
+        return 2048
+
+    def run(
+        self,
+        host: str,
+        port: int,
+        batch_size: int,
+        max_new_tokens: int = None,
+        num_runs: int = 1,
+    ):
+        """
+        Run the benchmark evaluation.
+
+        This method handles the common workflow:
+        1. Initialize backend
+        2. Load data
+        3. Create SGL function
+        4. Run inference loops
+        5. Compute metrics
+        6. Print results
+
+        Args:
+            host (str): The host of the SGLang server
+            port (int): The port of the SGLang server
+            batch_size (int): The number of prompts to process in parallel
+            max_new_tokens (int): Maximum number of new tokens to generate, default is 2048
+            num_runs (int): The number of times to run this benchmark, default is 1. You can set it to a larger number if you want to get more stable results.
+        """
+        if not host.startswith(("http://", "https://")):
+            host = f"http://{host}"
+        # Initialize backend
+        sglang_args = Namespace(host=host, port=port, backend="srt-no-parallel")
+        set_default_backend(select_sglang_backend(sglang_args))
+
+        # Load data
+        questions, labels = self.load_data()
+        if len(questions) == 0:
+            print("No valid questions found. Please check the dataset format.")
+            return
+
+        # Create SGL function
+        sgl_function = self.create_sgl_function()
+
+        # Run evaluation loops
+        metrics_list = []
+        answer_keys = self.get_answer_keys()
+        max_new_tokens = max_new_tokens or self.get_max_new_tokens()
+
+        for _ in range(num_runs):
+            tic = time.perf_counter()
+            states = sgl_function.run_batch(
+                questions,
+                temperature=0,
+                max_new_tokens=max_new_tokens,
+                num_threads=batch_size,
+                progress_bar=True,
+            )
+            latency = time.perf_counter() - tic
+
+            # Extract predictions
+            predictions = []
+            primary_answer_key = answer_keys[0] if answer_keys else "answer"
+            for i in range(len(states)):
+                # Access answer from state object (states[i] supports dict-like access)
+                output = states[i][primary_answer_key]
+                if isinstance(output, str):
+                    extracted = self.extract_answer(
+                        output,
+                        (labels[i] if labels and i < len(labels) else None),
+                    )
+                else:
+                    extracted = output
+                predictions.append(extracted)
+
+            # Compute accuracy if applicable
+            accuracy = None
+            # Check if we have a labels list (even if all labels are None)
+            has_labels_list = labels and len(labels) > 0
+
+            if has_labels_list:
+                # Always call compute_accuracy if we have a labels list
+                # This allows it to return None, which will be displayed in print_results
+                accuracy = self.compute_accuracy(predictions, labels)
+                if accuracy is not None:
+                    valid_count = sum(1 for p in predictions if p is not None)
+                    if valid_count < len(predictions):
+                        print(
+                            f"Warning: {len(predictions) - valid_count} predictions could not be extracted."
+                        )
+
+            # Compute performance metrics
+            metrics = compute_metrics(
+                states,
+                latency,
+                answer_key=primary_answer_key,
+                additional_answer_keys=(
+                    answer_keys[1:] if answer_keys and len(answer_keys) > 1 else None
+                ),
+            )
+            # Always set accuracy if we have a labels list (even if compute_accuracy returns None)
+            # This allows print_results to show None when compute_accuracy returns None
+            if has_labels_list:
+                metrics.accuracy = (
+                    accuracy  # Can be None if compute_accuracy returns None
+                )
+                if accuracy is not None:
+                    metrics.num_valid_predictions = sum(
+                        1 for p in predictions if p is not None
+                    )
+
+            metrics_list.append(metrics)
+        return metrics_list
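For reference, a minimal sketch of a new benchmark built on this base class, following the same pattern as the `financeqa` and `gpqa` implementations below; the benchmark name, dataset name, and column names are hypothetical placeholders:

```python
from typing import Any, Dict, List, Optional, Tuple

from datasets import load_dataset

from .base import Benchmarker
from .registry import BENCHMARKS
from .utils import create_simple_sgl_function


@BENCHMARKS.register("myqa")  # hypothetical benchmark name
class MyQABenchmarker(Benchmarker):
    """Sketch of a single-turn QA benchmark."""

    def __init__(self, num_samples: Optional[int] = None):
        super().__init__(num_samples, None)

    def load_data(self) -> Tuple[List[Dict[str, Any]], List[Any]]:
        ds = load_dataset("my-org/my-qa-dataset")["test"]  # hypothetical dataset
        questions, labels = [], []
        for idx, row in enumerate(ds):
            if self.num_samples is not None and idx >= self.num_samples:
                break
            questions.append({"question": row["question"]})  # hypothetical column
            labels.append(row["answer"])  # hypothetical column
        return questions, labels

    def create_sgl_function(self):
        return create_simple_sgl_function(
            function_name="get_myqa_answer",
            answer_key="answer",
            max_tokens=self.get_max_new_tokens(),
        )
```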
progress/SpecForge/benchmarks/benchmarker/ceval.py
ADDED
@@ -0,0 +1,267 @@
+"""
+C-Eval benchmark evaluation script.
+"""
+
+import re
+from typing import Any, Dict, List, Optional, Tuple
+
+from datasets import concatenate_datasets, load_dataset
+
+from .base import Benchmarker
+from .registry import BENCHMARKS
+from .utils import create_simple_sgl_function
+
+
+def extract_answer(answer_str: str) -> Optional[str]:
+    """Extract the answer choice (A, B, C, D) from the model output."""
+    # Try to find the answer in various formats
+    answer_str = answer_str.strip().upper()
+
+    # Direct match for single letter
+    match = re.search(r"\b([ABCD])\b", answer_str)
+    if match:
+        return match.group(1)
+
+    # Try to find answer in parentheses or brackets
+    for pattern in [
+        r"\(([ABCD])\)",
+        r"\[([ABCD])\]",
+        r"答案[::]\s*([ABCD])",
+        r"Answer[::]\s*([ABCD])",
+    ]:
+        match = re.search(pattern, answer_str, re.IGNORECASE)
+        if match:
+            return match.group(1).upper()
+
+    # Try to find the first occurrence of A, B, C, or D
+    match = re.search(r"([ABCD])", answer_str)
+    if match:
+        return match.group(1)
+
+    return None
+
+
+def format_question(question: str, options: List[str]) -> str:
+    """Format the question with options."""
+    prompt = question + "\n\n选项:\n"
+    for i, option in enumerate(options):
+        prompt += f"{chr(65 + i)}. {option}\n"
+    prompt += "\n请从A、B、C、D中选择一个答案。"
+    return prompt
+
+
+@BENCHMARKS.register("ceval")
+class CEvalBenchmarker(Benchmarker):
+    """C-Eval benchmark implementation."""
+
+    def __init__(
+        self, num_samples: Optional[int] = None, subset: Optional[List[str]] = None
+    ):
+        if subset is None:
+            subset = "all"
+        super().__init__(num_samples, subset)
+
+    def load_data(self) -> Tuple[List[Dict[str, Any]], List[str]]:
+        """Load and preprocess C-Eval dataset."""
+        all_configs = [
+            "accountant",
+            "advanced_mathematics",
+            "art_studies",
+            "basic_medicine",
+            "business_administration",
+            "chinese_language_and_literature",
+            "civil_servant",
+            "clinical_medicine",
+            "college_chemistry",
+            "college_economics",
+            "college_physics",
+            "college_programming",
+            "computer_architecture",
+            "computer_network",
+            "discrete_mathematics",
+            "education_science",
+            "electrical_engineer",
+            "environmental_impact_assessment_engineer",
+            "fire_engineer",
+            "high_school_biology",
+            "high_school_chemistry",
+            "high_school_chinese",
+            "high_school_geography",
+            "high_school_history",
+            "high_school_mathematics",
+            "high_school_physics",
+            "high_school_politics",
+            "ideological_and_moral_cultivation",
+            "law",
+            "legal_professional",
+            "logic",
+            "mao_zedong_thought",
+            "marxism",
+            "metrology_engineer",
+            "middle_school_biology",
+            "middle_school_chemistry",
+            "middle_school_geography",
+            "middle_school_history",
+            "middle_school_mathematics",
+            "middle_school_physics",
+            "middle_school_politics",
+            "modern_chinese_history",
+            "operating_system",
+            "physician",
+            "plant_protection",
+            "probability_and_statistics",
+            "professional_tour_guide",
+            "sports_science",
+            "tax_accountant",
+            "teacher_qualification",
+            "urban_and_rural_planner",
+            "veterinary_medicine",
+        ]
+
+        # Select configs to load
+        if self.subset == "all":
+            configs_to_load = all_configs
+        else:
+            for subset in self.subset:
+                assert (
+                    subset in all_configs
+                ), f"Subset {subset} not found in C-Eval dataset"
+            configs_to_load = self.subset
+
+        # Load datasets
+        try:
+            datasets = []
+            for config in configs_to_load:
+                try:
+                    ds = load_dataset("ceval/ceval-exam", name=config, split="test")
+                    datasets.append(ds)
+                    print(f"Loaded config '{config}' with {len(ds)} samples")
+                except Exception as e:
+                    print(f"Warning: Failed to load config '{config}': {e}")
+            if len(datasets) == 0:
+                raise ValueError("No configs could be loaded")
+            dataset = concatenate_datasets(datasets)
+            print(
+                f"Successfully loaded C-Eval dataset with all configs (total: {len(dataset)} samples)"
+            )
+        except Exception as e:
+            print(e)
+            print(f"Failed to load C-Eval dataset from 'ceval/ceval-exam': {e}")
+            print("Please ensure the dataset is available or install it manually.")
+            print("You can try: pip install datasets")
+            print("Or download from: https://huggingface.co/datasets/ceval/ceval-exam")
+            return [], []
+
+        # Process questions
+        questions = []
+        labels = []
+        for idx, item in enumerate(dataset):
+            if self.num_samples is not None and idx >= self.num_samples:
+                break
+
+            # Handle different dataset formats
+            question_text = None
+            if "question" in item:
+                question_text = item["question"]
+            elif "inputs" in item:
+                question_text = item["inputs"]
+            elif "problem" in item:
+                question_text = item["problem"]
+            elif "content" in item:
+                question_text = item["content"]
+
+            if not question_text:
+                continue
+
+            # Get options - C-Eval typically has options as a list or dict
+            options = None
+            if "options" in item:
+                options = item["options"]
+                if isinstance(options, dict):
+                    # Convert dict to list in order A, B, C, D
+                    options = [
+                        options.get("A", ""),
+                        options.get("B", ""),
+                        options.get("C", ""),
+                        options.get("D", ""),
+                    ]
+                elif isinstance(options, list):
+                    # Ensure we have 4 options
+                    while len(options) < 4:
+                        options.append("")
+            elif "choices" in item:
+                options = item["choices"]
+                if isinstance(options, dict):
+                    options = [
+                        options.get("A", ""),
+                        options.get("B", ""),
+                        options.get("C", ""),
+                        options.get("D", ""),
+                    ]
+            else:
+                # Try to construct options from A, B, C, D fields
+                options = [
+                    item.get("A", item.get("option_A", "")),
+                    item.get("B", item.get("option_B", "")),
+                    item.get("C", item.get("option_C", "")),
+                    item.get("D", item.get("option_D", "")),
+                ]
+
+            # Filter out empty options
+            if options:
+                options = [str(opt).strip() for opt in options if opt]
+                if len(options) < 2:  # Need at least 2 options
+                    continue
+            else:
+                continue
+
+            # Get answer
+            answer = None
+            if "answer" in item:
+                answer = str(item["answer"]).upper().strip()
+            elif "target" in item:
+                answer = str(item["target"]).upper().strip()
+            elif "label" in item:
+                answer = str(item["label"]).upper().strip()
+            elif "correct" in item:
+                answer = str(item["correct"]).upper().strip()
+
+            # Validate answer
+            if answer and answer in ["A", "B", "C", "D"]:
+                # Format question
+                formatted_question = format_question(question_text, options)
+                questions.append({"question": formatted_question})
+                labels.append(answer)
+
+        if len(questions) == 0:
+            print("No valid questions found. Please check the dataset format.")
+            print(
+                "Sample item keys:",
+                list(dataset[0].keys()) if len(dataset) > 0 else "No items",
+            )
+            return [], []
+
+        return questions, labels
+
+    def create_sgl_function(self):
+        """Create SGL function for C-Eval."""
+        return create_simple_sgl_function(
+            function_name="get_ceval_answer",
+            answer_key="answer",
+            max_tokens=self.get_max_new_tokens(),
+        )
+
+    def extract_answer(self, output: str, label: Any = None) -> Optional[str]:
+        """Extract answer choice from model output."""
+        return extract_answer(output)
+
+    def compute_accuracy(self, predictions: List[str], labels: List[str]) -> float:
+        """Compute accuracy metric."""
+        correct = 0
+        valid_count = 0
+        for i in range(len(predictions)):
+            if predictions[i] is not None:  # Only count valid predictions
+                valid_count += 1
+                if predictions[i] == labels[i]:
+                    correct += 1
+        return correct / valid_count if valid_count > 0 else 0.0
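A hypothetical sanity check for the two C-Eval helpers above (inputs invented):

```python
# "下列哪个数是质数?" means "Which of the following numbers is prime?"
print(format_question("下列哪个数是质数?", ["4", "6", "7", "9"]))
# Prints the question, the A-D option list, and the "pick one of A/B/C/D" instruction.

print(extract_answer("经过逐项分析,答案:C"))  # "C", matched as a standalone letter
print(extract_answer("I think the correct option is (B)."))  # "B"
```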
progress/SpecForge/benchmarks/benchmarker/financeqa.py
ADDED
@@ -0,0 +1,59 @@
from typing import Any, Dict, List, Optional, Tuple

from datasets import load_dataset

from .base import Benchmarker
from .registry import BENCHMARKS
from .utils import create_simple_sgl_function

QUESTION_PROMPT = """
Given the following context:

{context}

Can you answer the following question?

{question}
""".strip()


def generate_question(row: Dict[str, Any]) -> str:
    if row["context"] is None:
        return row["question"].strip()
    else:
        question = QUESTION_PROMPT.format(
            context=row["context"].strip(),
            question=row["question"].strip(),
        )
        return question


@BENCHMARKS.register("financeqa")
class FinanceQABenchmarker(Benchmarker):
    """FinanceQA benchmark implementation."""

    def __init__(self, num_samples: Optional[int] = None):
        super().__init__(num_samples, None)

    def load_data(self) -> Tuple[List[Dict[str, Any]], List[None]]:
        """Load and preprocess FinanceQA dataset."""
        # Read data
        ds = load_dataset("AfterQuery/FinanceQA")["test"]

        questions = []
        labels = []
        for i in range(len(ds)):
            if self.num_samples is not None and i >= self.num_samples:
                break

            question_text = generate_question(ds[i])
            questions.append({"question": question_text})
            labels.append(None)
        return questions, labels

    def create_sgl_function(self):
        return create_simple_sgl_function(
            function_name="get_financeqa_answer",
            answer_key="answer",
            max_tokens=self.get_max_new_tokens(),
        )
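`generate_question` above only wraps the prompt template around rows that actually carry a context passage. A quick illustration with made-up rows (the field names follow the code above; the values are invented):

```python
row_with_context = {
    "context": "Revenue was $10M in FY23.",
    "question": "What was FY23 revenue?",
}
row_without_context = {"context": None, "question": "Define EBITDA."}

print(generate_question(row_with_context))
# Given the following context:
#
# Revenue was $10M in FY23.
#
# Can you answer the following question?
#
# What was FY23 revenue?

print(generate_question(row_without_context))
# Define EBITDA.
```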
progress/SpecForge/benchmarks/benchmarker/gpqa.py
ADDED
@@ -0,0 +1,85 @@
import random
from typing import Any, Dict, List, Optional, Tuple

from datasets import load_dataset

from .base import Benchmarker
from .registry import BENCHMARKS
from .utils import create_simple_sgl_function

GPQA_QUERY_TEMPLATE = """
Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.

{Question}

A) {A}
B) {B}
C) {C}
D) {D}
""".strip()


def generate_question(row: Dict[str, Any]) -> Tuple[str, str]:
    gold_index = random.randint(0, 3)
    choices = [
        row["Incorrect Answer 1"],
        row["Incorrect Answer 2"],
        row["Incorrect Answer 3"],
    ]
    choices.insert(gold_index, row["Correct Answer"])

    question = GPQA_QUERY_TEMPLATE.format(
        Question=row["Question"].strip(),
        A=choices[0].strip(),
        B=choices[1].strip(),
        C=choices[2].strip(),
        D=choices[3].strip(),
    )

    # 0 means A, 1 means B, 2 means C, 3 means D
    answer = ["A", "B", "C", "D"][gold_index]
    return question, answer


@BENCHMARKS.register("gpqa")
class GPQABenchmarker(Benchmarker):
    """GPQA benchmark implementation."""

    def __init__(self, num_samples: Optional[int] = None):
        super().__init__(num_samples, None)

    def load_data(self) -> Tuple[List[Dict[str, Any]], List[str]]:
        """Load and preprocess GPQA dataset."""
        # Read data
        ds = load_dataset("Idavidrein/gpqa", "gpqa_main")["train"]

        questions = []
        labels = []
        for i in range(len(ds)):
            if self.num_samples is not None and i >= self.num_samples:
                break

            question_text, answer = generate_question(ds[i])
            questions.append({"question": question_text})
            labels.append(answer)
        return questions, labels

    def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[str]:
        if "Answer: " not in output:
            return None
        return output.split("Answer: ")[1].strip()

    def compute_accuracy(
        self, predictions: List[Any], labels: List[Any]
    ) -> Optional[float]:
        if not labels or len(labels) == 0:
            return None
        correct = sum(1 for pred, label in zip(predictions, labels) if pred == label)
        return correct / len(labels) if len(labels) > 0 else 0.0

    def create_sgl_function(self):
        return create_simple_sgl_function(
            function_name="get_gpqa_mcq_answer",
            answer_key="answer",
            max_tokens=self.get_max_new_tokens(),
        )
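One behavior of `generate_question` above worth flagging: the gold answer position comes from an unseeded `random.randint(0, 3)`, so the question-to-letter mapping changes on every run. If reproducible labels matter, a dedicated seeded RNG is one option; this is a sketch of that idea, not something the file does:

```python
import random

rng = random.Random(42)  # fixed seed -> identical gold positions across runs

def place_gold(incorrect, correct):
    gold_index = rng.randint(0, 3)
    choices = list(incorrect)
    choices.insert(gold_index, correct)
    return choices, "ABCD"[gold_index]

choices, gold = place_gold(["x", "y", "z"], "w")
print(choices, gold)  # deterministic given the seed above
```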
progress/SpecForge/benchmarks/benchmarker/gsm8k.py
ADDED
@@ -0,0 +1,99 @@
"""
GSM8K benchmark evaluation script.
"""

import ast
import re
from typing import Any, Dict, List, Optional, Tuple

from sglang.utils import download_and_cache_file, read_jsonl

from .base import Benchmarker
from .registry import BENCHMARKS
from .utils import create_few_shot_sgl_function

INVALID = -9999999


def get_one_example(lines: List[Dict], i: int, include_answer: bool) -> str:
    """Format a single example."""
    ret = "Question: " + lines[i]["question"] + "\nAnswer:"
    if include_answer:
        ret += " " + lines[i]["answer"]
    return ret


def get_few_shot_examples(lines: List[Dict], k: int) -> str:
    """Get few-shot examples as a string."""
    ret = ""
    for i in range(k):
        ret += get_one_example(lines, i, True) + "\n\n"
    return ret


def get_answer_value(answer_str: str) -> int:
    """Extract numeric answer from model output."""
    answer_str = answer_str.replace(",", "")
    numbers = re.findall(r"\d+", answer_str)
    if len(numbers) < 1:
        return INVALID
    try:
        return ast.literal_eval(numbers[-1])
    except SyntaxError:
        return INVALID


@BENCHMARKS.register("gsm8k")
class GSM8KBenchmarker(Benchmarker):
    """GSM8K benchmark implementation."""

    def __init__(self, num_samples: Optional[int] = None):
        super().__init__(num_samples, None)

    def load_data(self) -> Tuple[List[Dict[str, Any]], List[int]]:
        """Load and preprocess GSM8K dataset."""
        # Read data
        url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
        data_path = download_and_cache_file(url)
        lines = list(read_jsonl(data_path))

        # Construct prompts
        few_shot_examples = get_few_shot_examples(lines, 5)

        questions = []
        labels = []
        for i in range(len(lines)):
            if self.num_samples is not None and i >= self.num_samples:
                break

            question_text = get_one_example(lines, i, False)
            questions.append({"question": question_text})
            labels.append(get_answer_value(lines[i]["answer"]))

        # Store few_shot_examples for use in create_sgl_function
        self.few_shot_examples = few_shot_examples

        assert all(l != INVALID for l in labels), "Some labels are invalid"
        return questions, labels

    def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[int]:
        """Extract numeric answer from model output."""
        return get_answer_value(output)

    def compute_accuracy(
        self, predictions: List[Any], labels: List[Any]
    ) -> Optional[float]:
        """Compute accuracy for GSM8K by comparing numeric answers."""
        if not labels or len(labels) == 0:
            return None
        correct = sum(1 for pred, label in zip(predictions, labels) if pred == label)
        return correct / len(labels) if len(labels) > 0 else 0.0

    def create_sgl_function(self):
        """Create SGL function for GSM8K with few-shot examples."""
        return create_few_shot_sgl_function(
            few_shot_examples=self.few_shot_examples,
            function_name="few_shot_gsm8k",
            answer_key="answer",
            stop=["Question", "Assistant:", "<|separator|>"],
        )
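The reason `get_answer_value` can simply take the last integer is the GSM8K answer format: gold answers end with a `#### <number>` line, and commas are stripped first. A standalone re-implementation of the same logic for illustration (using `int` in place of `ast.literal_eval`, which is equivalent for digit-only matches):

```python
import re

def last_int(answer_str: str, invalid: int = -9999999) -> int:
    # Strip thousands separators, then take the final run of digits.
    answer_str = answer_str.replace(",", "")
    numbers = re.findall(r"\d+", answer_str)
    return int(numbers[-1]) if numbers else invalid

print(last_int("Natalia sold 48 + 24 = 72 clips.\n#### 72"))  # 72
print(last_int("The total is 1,234.\n#### 1,234"))            # 1234
print(last_int("No digits here"))                             # -9999999
```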
progress/SpecForge/benchmarks/benchmarker/humaneval.py
ADDED
@@ -0,0 +1,188 @@
"""
HumanEval benchmark evaluation script.
"""

import re
from typing import Any, Dict, List, Optional, Tuple

from datasets import load_dataset

from .base import Benchmarker
from .registry import BENCHMARKS
from .utils import create_simple_sgl_function


def extract_code_from_output(output: str) -> Optional[str]:
    """Extract Python code from model output.

    Tries to extract code blocks or function definitions.
    """
    # Try to find code in markdown code blocks
    code_block_pattern = r"```(?:python)?\n(.*?)```"
    match = re.search(code_block_pattern, output, re.DOTALL)
    if match:
        return match.group(1).strip()

    # Try to find function definition (common in HumanEval)
    # Look for "def " followed by code until the next def or end of string
    def_pattern = r"(def\s+\w+\([^)]*\):.*?)(?=\n\ndef\s+|\Z)"
    match = re.search(def_pattern, output, re.DOTALL)
    if match:
        return match.group(1).strip()

    # Fallback: return the output as-is (might already be code)
    return output.strip() if output.strip() else None


def check_code_passes_tests(code: str, test_code: str, entry_point: str) -> bool:
    """Check if generated code passes the test cases.

    This is a simplified version. For full evaluation, use the official
    HumanEval evaluation framework.

    HumanEval test code typically contains assertions that will raise
    AssertionError if the code doesn't pass. If execution completes without
    exceptions, the tests pass.
    """
    try:
        # Create a safe execution environment
        namespace = {}
        # Execute the code (function definition)
        exec(code, namespace)
        # Execute the test code (which contains assertions)
        # If no exception is raised, the tests pass
        exec(test_code, namespace)
        return True
    except AssertionError:
        # Assertion failed - test didn't pass
        return False
    except Exception:
        # Any other exception (syntax error, runtime error, etc.) means test failed
        return False


@BENCHMARKS.register("humaneval")
class HumanEvalBenchmarker(Benchmarker):
    """HumanEval benchmark implementation."""

    def __init__(self, num_samples: Optional[int] = None):
        """Initialize benchmark and store test cases."""
        super().__init__(num_samples, None)
        self.test_cases = []
        self.entry_points = []

    def load_data(self) -> Tuple[List[Dict[str, Any]], List[Optional[Dict[str, str]]]]:
        """Load and preprocess HumanEval dataset."""
        dataset = load_dataset("openai/openai_humaneval")["test"]
        questions = []
        labels = []
        self.test_cases = []
        self.entry_points = []

        for idx, q in enumerate(dataset):
            if self.num_samples is not None and idx >= self.num_samples:
                break

            questions.append({"question": q["prompt"]})

            # Store test case and entry point for evaluation
            test_code = q.get("test", "")
            entry_point = q.get("entry_point", "")
            self.test_cases.append(test_code)
            self.entry_points.append(entry_point)

            # Store canonical solution as reference (optional, for comparison)
            canonical_solution = q.get("canonical_solution", "")
            labels.append(
                {
                    "test": test_code,
                    "entry_point": entry_point,
                    "canonical_solution": canonical_solution,
                }
            )

        return questions, labels

    def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[str]:
        """Extract code from model output."""
        return extract_code_from_output(output)

    def compute_accuracy(
        self, predictions: List[Any], labels: List[Any]
    ) -> Optional[float]:
        """Compute accuracy for HumanEval by checking if code passes tests.

        Note: This is a simplified evaluation. For official pass@k metrics,
        use the HumanEval evaluation framework.
        """
        if not labels or len(labels) == 0:
            return None
        if all(label is None for label in labels):
            return None

        correct = 0
        valid_count = 0

        for i, (pred, label) in enumerate(zip(predictions, labels)):
            if label is not None and isinstance(label, dict):
                valid_count += 1
                if pred is not None:
                    try:
                        # Get the prompt (function signature and docstring)
                        prompt = self.questions[i]["question"]
                        entry_point = label.get("entry_point", "")

                        # The prompt contains the function signature (e.g., "def function_name(...):")
                        # The generated code might be:
                        # 1. Just the function body (what we want) - need to combine with prompt
                        # 2. The complete function including signature - use as-is
                        # 3. Code in markdown blocks - already extracted by extract_code_from_output

                        pred_str = str(pred).strip()

                        # Check if pred already contains a complete function definition
                        # (starts with "def " and contains the entry_point function name)
                        if pred_str.startswith("def ") and entry_point:
                            # Check if this is the same function (by name)
                            func_name_match = re.match(r"def\s+(\w+)\s*\(", pred_str)
                            if (
                                func_name_match
                                and func_name_match.group(1) == entry_point
                            ):
                                # Generated code includes complete function, use it as-is
                                full_code = pred_str
                            else:
                                # Different function or no match, combine with prompt
                                full_code = prompt + "\n" + pred_str
                        elif pred_str.startswith("def "):
                            # Has function definition but we can't verify entry_point, use as-is
                            full_code = pred_str
                        else:
                            # Generated code is just the body, combine with prompt
                            full_code = prompt + "\n" + pred_str

                        # Check if code passes tests
                        test_code = label.get("test", "")

                        if test_code and check_code_passes_tests(
                            full_code, test_code, entry_point
                        ):
                            correct += 1
                    except Exception as e:
                        # If evaluation fails, consider it incorrect
                        # Uncomment for debugging: print(f"Error evaluating code {i}: {e}")
                        pass

        return correct / valid_count if valid_count > 0 else 0.0

    def create_sgl_function(self):
        """Create SGL function for HumanEval."""
        return create_simple_sgl_function(
            function_name="get_humaneval_answer",
            answer_key="answer",
            max_tokens=self.get_max_new_tokens(),
        )

    def get_max_new_tokens(self) -> int:
        """HumanEval code generation requires more tokens."""
        return 1024
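`check_code_passes_tests` runs untrusted model code with `exec` in the benchmark process itself, so a completion that loops forever or touches the filesystem can stall or damage the run; the docstring already flags this as simplified. A common hardening, sketched here as an assumption rather than what the file does, is to execute each candidate in a subprocess with a timeout:

```python
import multiprocessing

def _run_candidate(code: str, test_code: str, q) -> None:
    try:
        namespace = {}
        exec(code, namespace)       # define the candidate function
        exec(test_code, namespace)  # assertions raise on failure
        q.put(True)
    except Exception:
        q.put(False)

def check_with_timeout(code: str, test_code: str, timeout: float = 5.0) -> bool:
    # Needs the usual `if __name__ == "__main__":` guard on spawn-based platforms.
    q = multiprocessing.Queue()
    p = multiprocessing.Process(target=_run_candidate, args=(code, test_code, q))
    p.start()
    p.join(timeout)
    if p.is_alive():  # candidate hung: kill it and count it as a failure
        p.terminate()
        p.join()
        return False
    return q.get() if not q.empty() else False
```

This is still not a sandbox (the subprocess inherits filesystem and network access); the official HumanEval harness makes the same caveat.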
progress/SpecForge/benchmarks/benchmarker/livecodebench.py
ADDED
@@ -0,0 +1,46 @@
"""
LiveCodeBench benchmark evaluation script.
"""

from typing import Any, Dict, List, Optional, Tuple

from datasets import load_dataset

from .base import Benchmarker
from .registry import BENCHMARKS
from .utils import create_simple_sgl_function


def generate_question(row: Dict[str, Any]) -> str:
    question = row["question_content"].strip()
    return question


@BENCHMARKS.register("livecodebench")
class LCBBenchmarker(Benchmarker):
    """LiveCodeBench benchmark implementation."""

    def __init__(self, num_samples: Optional[int] = None):
        super().__init__(num_samples, None)

    def load_data(self) -> Tuple[List[Dict[str, Any]], List[None]]:
        # Read data
        ds = load_dataset("livecodebench/code_generation")["test"]

        questions = []
        labels = []
        for i in range(len(ds)):
            if self.num_samples is not None and i >= self.num_samples:
                break

            question_text = generate_question(ds[i])
            questions.append({"question": question_text})
            labels.append(None)
        return questions, labels

    def create_sgl_function(self):
        return create_simple_sgl_function(
            function_name="get_livecodebench_answer",
            answer_key="answer",
            max_tokens=self.get_max_new_tokens(),
        )
progress/SpecForge/benchmarks/benchmarker/math500.py
ADDED
@@ -0,0 +1,122 @@
"""
MATH-500 benchmark evaluation script.
"""

import re
from typing import Any, Dict, List, Optional, Tuple

from datasets import load_dataset

from .base import Benchmarker
from .registry import BENCHMARKS
from .utils import create_simple_sgl_function


def extract_math_answer(output: str) -> Optional[str]:
    r"""Extract final answer from math problem solution.

    Tries to extract answer from \boxed{} format first, then looks for
    the last number in the output.
    """
    # Try to find answer in \boxed{} format
    boxed_pattern = r"\\boxed\{([^}]+)\}"
    match = re.search(boxed_pattern, output)
    if match:
        return match.group(1).strip()

    # Try to find answer in \boxed format (without braces)
    boxed_pattern2 = r"\\boxed\s+([^\s]+)"
    match = re.search(boxed_pattern2, output)
    if match:
        return match.group(1).strip()

    # Try to find the last number (could be integer or decimal)
    # Look for patterns like "The answer is 42" or "Answer: 3.14"
    answer_patterns = [
        r"(?:answer|Answer|ANSWER)[\s:]+([-+]?\d*\.?\d+)",
        r"(?:is|equals?|=\s*)([-+]?\d*\.?\d+)\s*$",
    ]
    for pattern in answer_patterns:
        matches = re.findall(pattern, output, re.IGNORECASE)
        if matches:
            return matches[-1].strip()

    # Fallback: extract the last number in the text
    numbers = re.findall(r"[-+]?\d*\.?\d+", output)
    if numbers:
        return numbers[-1]

    return None


@BENCHMARKS.register("math500")
class Math500Benchmarker(Benchmarker):
    """MATH-500 benchmark implementation."""

    def __init__(self, num_samples: Optional[int] = None):
        super().__init__(num_samples, None)

    def load_data(self) -> Tuple[List[Dict[str, Any]], List[Optional[str]]]:
        """Load and preprocess MATH-500 dataset."""
        dataset = load_dataset("HuggingFaceH4/MATH-500")["test"]
        questions = []
        labels = []
        for idx, q in enumerate(dataset):
            if self.num_samples is not None and idx >= self.num_samples:
                break

            questions.append({"question": q["problem"]})
            # Extract answer from solution or answer field
            answer = None
            if "answer" in q:
                answer = str(q["answer"]).strip()
            elif "solution" in q:
                # Try to extract from solution
                answer = extract_math_answer(q["solution"])
            labels.append(answer)
        return questions, labels

    def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[str]:
        """Extract answer from model output."""
        return extract_math_answer(output)

    def compute_accuracy(
        self, predictions: List[Any], labels: List[Any]
    ) -> Optional[float]:
        """Compute accuracy for MATH-500 by comparing answers."""
        if not labels or len(labels) == 0:
            return None
        if all(label is None for label in labels):
            return None

        correct = 0
        valid_count = 0
        for pred, label in zip(predictions, labels):
            if label is not None:
                valid_count += 1
                if pred is not None:
                    # Normalize answers for comparison (remove whitespace, handle different formats)
                    pred_normalized = str(pred).strip().lower()
                    label_normalized = str(label).strip().lower()
                    # Try exact match first
                    if pred_normalized == label_normalized:
                        correct += 1
                    else:
                        # Try numeric comparison if both are numbers
                        try:
                            pred_num = float(pred_normalized)
                            label_num = float(label_normalized)
                            if abs(pred_num - label_num) < 1e-6:
                                correct += 1
                        except ValueError:
                            pass

        return correct / valid_count if valid_count > 0 else 0.0

    def create_sgl_function(self):
        """Create SGL function for MATH-500."""
        return create_simple_sgl_function(
            function_name="get_math500_answer",
            answer_key="answer",
            max_tokens=self.get_max_new_tokens(),
        )
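One limitation of the `\\boxed\{([^}]+)\}` pattern above: it stops at the first closing brace, so a nested answer such as `\boxed{\frac{1}{2}}` comes back truncated as `\frac{1`. If that matters for the problems being scored, a brace-balanced scan recovers the full argument; this is an alternative sketch, not what the module does:

```python
def extract_boxed(text: str):
    """Return the brace-balanced argument of the last \\boxed{...}, or None."""
    start = text.rfind("\\boxed{")
    if start == -1:
        return None
    i = start + len("\\boxed{")
    begin, depth = i, 1
    while i < len(text) and depth:
        depth += {"{": 1, "}": -1}.get(text[i], 0)
        i += 1
    return text[begin : i - 1] if depth == 0 else None

print(extract_boxed(r"so the answer is \boxed{\frac{1}{2}}"))  # \frac{1}{2}
```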
progress/SpecForge/benchmarks/benchmarker/mmlu.py
ADDED
@@ -0,0 +1,82 @@
from typing import Any, Dict, List, Optional, Tuple

from datasets import load_dataset

from .base import Benchmarker
from .registry import BENCHMARKS
from .utils import create_simple_sgl_function

GPQA_QUERY_TEMPLATE = """
Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.

{Question}

A) {A}
B) {B}
C) {C}
D) {D}
""".strip()


def generate_question(row: Dict[str, Any]) -> Tuple[str, str]:
    choices = row["choices"]
    question = GPQA_QUERY_TEMPLATE.format(
        Question=row["question"].strip(),
        A=choices[0].strip(),
        B=choices[1].strip(),
        C=choices[2].strip(),
        D=choices[3].strip(),
    )

    # 0 means A, 1 means B, 2 means C, 3 means D
    answer = ["A", "B", "C", "D"][row["answer"]]
    return question, answer


@BENCHMARKS.register("mmlu")
class MMLUBenchmarker(Benchmarker):
    """MMLU benchmark implementation."""

    def __init__(
        self, num_samples: Optional[int] = None, subset: Optional[List[str]] = None
    ):
        if subset is None:
            subset = ["all"]
        super().__init__(num_samples, subset)

    def load_data(self) -> Tuple[List[Dict[str, Any]], List[str]]:
        # Read data
        questions = []
        labels = []

        for subset in self.subset:
            ds = load_dataset("cais/mmlu", subset)["test"]
            for i in range(len(ds)):
                if self.num_samples is not None and i >= self.num_samples:
                    break

                question_text, answer = generate_question(ds[i])
                questions.append({"question": question_text})
                labels.append(answer)
        return questions, labels

    def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[str]:
        if "Answer: " not in output:
            return None
        return output.split("Answer: ")[1].strip()

    def compute_accuracy(
        self, predictions: List[Any], labels: List[Any]
    ) -> Optional[float]:
        if not labels or len(labels) == 0:
            return None
        correct = sum(1 for pred, label in zip(predictions, labels) if pred == label)
        return correct / len(labels) if len(labels) > 0 else 0.0

    def create_sgl_function(self):
        return create_simple_sgl_function(
            function_name="get_mmlu_answer",
            answer_key="answer",
            max_tokens=self.get_max_new_tokens(),
        )
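`extract_answer` above returns everything after the `"Answer: "` marker, so an output like `Answer: B) Paris` yields `"B) Paris"`, which never equals the gold letter `"B"` and is scored wrong. A tighter extraction that keeps only the letter is sketched below; it is an illustration, not the module's code:

```python
import re

def extract_letter(output: str):
    m = re.search(r"Answer:\s*\(?([ABCD])\)?", output)
    return m.group(1) if m else None

print(extract_letter("Thinking step by step...\nAnswer: B) Paris"))  # B
print(extract_letter("Answer: (C)"))                                 # C
print(extract_letter("no final answer line"))                        # None
```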
progress/SpecForge/benchmarks/benchmarker/mmstar.py
ADDED
@@ -0,0 +1,185 @@
"""
MMStar benchmark evaluation script.
"""

import os
import re
import shutil
from typing import Any, Dict, List, Optional, Tuple

from datasets import load_dataset

from .base import Benchmarker
from .registry import BENCHMARKS
from .utils import create_image_sgl_function


def extract_mmstar_answer(
    output: str, options: Optional[List[str]] = None
) -> Optional[str]:
    """Extract answer from MMStar model output.

    MMStar questions typically have multiple choice options (A, B, C, D, etc.)
    """
    output_upper = output.strip().upper()

    # Try to find answer choice (A, B, C, D, etc.)
    # Direct match for single letter
    match = re.search(r"\b([A-Z])\b", output_upper)
    if match:
        letter = match.group(1)
        if options and len(options) > 0:
            # Validate that the letter is within valid range
            max_option = chr(64 + len(options))  # 'A' + (len-1)
            if "A" <= letter <= max_option:
                return letter
        else:
            # Assume A-D are valid
            if "A" <= letter <= "D":
                return letter

    # Try to find answer in parentheses or brackets
    for pattern in [
        r"\(([A-Z])\)",
        r"\[([A-Z])\]",
        r"答案[::]\s*([A-Z])",
        r"Answer[::]\s*([A-Z])",
        r"选择[::]\s*([A-Z])",
    ]:
        match = re.search(pattern, output_upper)
        if match:
            letter = match.group(1)
            if options and len(options) > 0:
                max_option = chr(64 + len(options))
                if "A" <= letter <= max_option:
                    return letter
            elif "A" <= letter <= "D":
                return letter

    return None


@BENCHMARKS.register("mmstar")
class MMStarBenchmarker(Benchmarker):
    """MMStar benchmark implementation."""

    def __init__(self, num_samples: Optional[int] = None):
        """Initialize benchmark and set up cache directory."""
        super().__init__(num_samples, None)
        self.cache_dir = None
        self.options_list = []  # Store options for each question

    def load_data(self) -> Tuple[List[Dict[str, Any]], List[Optional[str]]]:
        """Load and preprocess MMStar dataset."""
        self.cache_dir = os.path.join(".cache", "mmstar_specforge")
        image_dir = os.path.join(self.cache_dir, "images")
        os.makedirs(self.cache_dir, exist_ok=True)
        os.makedirs(image_dir, exist_ok=True)
        print(f"Created temporary image directory: {self.cache_dir}")

        dataset = load_dataset("Lin-Chen/MMStar")["val"]
        questions = []
        labels = []
        self.options_list = []

        for idx, q in enumerate(dataset):
            if self.num_samples is not None and idx >= self.num_samples:
                break

            image = q["image"]
            image_path = os.path.join(self.cache_dir, q["meta_info"]["image_path"])
            image.convert("RGB").save(image_path, "JPEG")

            # Extract question and options
            question_full = q["question"]
            if "Options:" in question_full:
                question_text, options_text = question_full.split("Options:", 1)
                question_text = question_text.strip()
                # Parse options (typically A. option1 B. option2 etc.)
                options = []
                for line in options_text.strip().split("\n"):
                    line = line.strip()
                    if line and re.match(r"^[A-Z]\.", line):
                        option_text = re.sub(r"^[A-Z]\.\s*", "", line).strip()
                        options.append(option_text)
                self.options_list.append(options)
            else:
                question_text = question_full.strip()
                self.options_list.append([])

            item = {
                "image_path": image_path,
                "question": question_text,
            }
            questions.append(item)

            # Extract ground truth answer
            answer = None
            if "answer" in q:
                answer = str(q["answer"]).strip().upper()
            elif "correct_answer" in q:
                answer = str(q["correct_answer"]).strip().upper()
            elif "ground_truth" in q:
                answer = str(q["ground_truth"]).strip().upper()

            # Validate answer is a valid option letter
            if answer and len(answer) == 1 and "A" <= answer <= "Z":
                if self.options_list[-1]:
                    max_option = chr(64 + len(self.options_list[-1]))
                    if answer <= max_option:
                        labels.append(answer)
                    else:
                        labels.append(None)
                else:
                    labels.append(answer)
            else:
                labels.append(None)

        return questions, labels

    def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[str]:
        """Extract answer from model output."""
        # Use the options for the current question if available
        # Note: We can't easily get the question index here, so we'll use a simpler approach
        return extract_mmstar_answer(output)

    def compute_accuracy(
        self, predictions: List[Any], labels: List[Any]
    ) -> Optional[float]:
        """Compute accuracy for MMStar by comparing answer choices."""
        if not labels or len(labels) == 0:
            return None
        if all(label is None for label in labels):
            return None

        correct = 0
        valid_count = 0
        for pred, label in zip(predictions, labels):
            if label is not None:
                valid_count += 1
                if pred is not None:
                    # Normalize to uppercase for comparison
                    pred_normalized = str(pred).strip().upper()
                    label_normalized = str(label).strip().upper()
                    if pred_normalized == label_normalized:
                        correct += 1

        return correct / valid_count if valid_count > 0 else 0.0

    def create_sgl_function(self):
        """Create SGL function for MMStar (image-based Q&A)."""
        return create_image_sgl_function(
            function_name="get_mmstar_answer",
            answer_key="answer",
            max_tokens=self.get_max_new_tokens(),
        )

    def run(self, *args, **kwargs):
        """Run benchmark and clean up cache directory."""
        try:
            return super().run(*args, **kwargs)
        finally:
            # Clean up cache directory
            if self.cache_dir and os.path.exists(self.cache_dir):
                shutil.rmtree(self.cache_dir)
                print(f"Deleted temporary directory: {self.cache_dir}")
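The first fallback in `extract_mmstar_answer`, `re.search(r"\b([A-Z])\b", output_upper)`, grabs the first standalone capital anywhere in the uppercased output, so a sentence-initial article can win: in "A close look shows the answer is C" the stray "A" is returned instead of "C". Preferring an explicit final-answer marker before the loose single-letter fallback avoids this; a sketch of that ordering (an assumption, not the module's behavior):

```python
import re

def extract_choice(output: str, valid: str = "ABCD"):
    up = output.strip().upper()
    # 1) explicit marker first, 2) loose standalone letter only as a last resort
    m = re.search(r"ANSWER\s*(?:[::]|IS)\s*\(?([A-Z])\)?\b", up)
    if not m:
        m = re.search(r"\b([A-Z])\b", up)
    letter = m.group(1) if m else None
    return letter if letter in set(valid) else None

print(extract_choice("A close look shows the answer is C"))  # C, not the article "A"
print(extract_choice("B"))                                   # B
```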
progress/SpecForge/benchmarks/benchmarker/mtbench.py
ADDED
@@ -0,0 +1,59 @@
"""
MT-Bench benchmark evaluation script.
Adapted from https://github.com/chromecast56/sglang/blob/6f145d2eadb93a116134f703358ce76f15381045/benchmark/mtbench/bench_sglang.py
"""

from typing import Any, Dict, List, Optional, Tuple

from sglang.utils import download_and_cache_file, read_jsonl

from .base import Benchmarker
from .registry import BENCHMARKS
from .utils import create_multi_turn_sgl_function

SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."


@BENCHMARKS.register("mtbench")
class MTBenchBenchmarker(Benchmarker):
    """MT-Bench benchmark implementation."""

    def __init__(
        self, num_samples: Optional[int] = None, subset: Optional[List[str]] = None
    ):
        # support categorical data for mtbench
        if subset is None:
            subset = ["all"]
        super().__init__(num_samples, subset)

    def load_data(self) -> Tuple[List[Dict[str, Any]], List[None]]:
        """Load and preprocess MT-Bench dataset."""
        url = "https://raw.githubusercontent.com/lm-sys/FastChat/main/fastchat/llm_judge/data/mt_bench/question.jsonl"
        download_and_cache_file(url, filename="mtbench.jsonl")
        questions_data = list(read_jsonl("mtbench.jsonl"))

        questions = [
            {"question_1": q["turns"][0], "question_2": q["turns"][1]}
            for q in questions_data
        ]
        # MT-Bench doesn't have labels for accuracy computation
        labels = [None] * len(questions)

        if self.num_samples is not None:
            questions = questions[: self.num_samples]
            labels = labels[: self.num_samples]
        return questions, labels

    def create_sgl_function(self):
        """Create SGL function for MT-Bench (2-turn conversation)."""
        return create_multi_turn_sgl_function(
            function_name="answer_mt_bench",
            system_prompt=SYSTEM_PROMPT,
            num_turns=2,
            max_tokens=self.get_max_new_tokens(),
        )

    def get_answer_keys(self) -> List[str]:
        """Return answer keys for multi-turn conversation."""
        return ["answer_1", "answer_2"]
progress/SpecForge/benchmarks/benchmarker/registry.py
ADDED
@@ -0,0 +1,31 @@
class BenchmarkRegistry:

    def __init__(self):
        self.benchmarks = {}

    def register(self, name: str):
        """
        Usage:
        ```python
        BENCHMARKS = BenchmarkRegistry()

        @BENCHMARKS.register("aime")
        class AIMEBenchmarker(Benchmarker):
            ...
        ```
        """

        def wrapper(cls):
            self.benchmarks[name] = cls
            return cls

        return wrapper

    def get(self, name: str) -> type:
        """
        Get the benchmark class by name.
        """
        return self.benchmarks[name]


BENCHMARKS = BenchmarkRegistry()
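Because registration happens as a side effect of the class decorator, a benchmark becomes available as soon as its module is imported; running one is then a dictionary lookup plus instantiation. A usage sketch (package paths inferred from the tree above):

```python
from benchmarks.benchmarker.registry import BENCHMARKS
import benchmarks.benchmarker.gsm8k  # import runs @BENCHMARKS.register("gsm8k")

cls = BENCHMARKS.get("gsm8k")   # -> GSM8KBenchmarker
bench = cls(num_samples=8)      # small smoke-test slice
questions, labels = bench.load_data()
print(len(questions), labels[:3])
```

This is also why a package `__init__` typically imports every benchmarker module up front; otherwise `BENCHMARKS.get` raises `KeyError` for benchmarks whose modules were never imported.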
progress/SpecForge/benchmarks/benchmarker/simpleqa.py
ADDED
@@ -0,0 +1,42 @@
from typing import Any, Dict, List, Optional, Tuple

from datasets import load_dataset

from .base import Benchmarker
from .registry import BENCHMARKS
from .utils import create_simple_sgl_function


def generate_question(row: Dict[str, Any]) -> str:
    question = row["problem"].strip()
    return question


@BENCHMARKS.register("simpleqa")
class SimpleQABenchmarker(Benchmarker):
    """SimpleQA benchmark implementation."""

    def __init__(self, num_samples: Optional[int] = None):
        super().__init__(num_samples, None)

    def load_data(self) -> Tuple[List[Dict[str, Any]], List[None]]:
        # Read data
        ds = load_dataset("basicv8vc/SimpleQA")["test"]

        questions = []
        labels = []
        for i in range(len(ds)):
            if self.num_samples is not None and i >= self.num_samples:
                break

            question_text = generate_question(ds[i])
            questions.append({"question": question_text})
            labels.append(None)
        return questions, labels

    def create_sgl_function(self):
        return create_simple_sgl_function(
            function_name="get_simpleqa_answer",
            answer_key="answer",
            max_tokens=self.get_max_new_tokens(),
        )
progress/SpecForge/benchmarks/benchmarker/utils.py
ADDED
@@ -0,0 +1,273 @@
"""
Utility functions for benchmark scripts.
"""

from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional

import numpy as np
import sglang as sgl


@dataclass
class BenchmarkMetrics:
    """Container for benchmark performance metrics."""

    latency: float
    output_throughput: float
    accept_length: float
    accuracy: Optional[float] = None
    num_questions: int = 0
    num_valid_predictions: int = 0
    categorical_performance: Optional[Dict[str, "BenchmarkMetrics"]] = None


def compute_metrics(
    states: List[Any],
    latency: float,
    answer_key: str = "answer",
    additional_answer_keys: Optional[List[str]] = None,
) -> BenchmarkMetrics:
    """
    Compute performance metrics from SGLang states.

    Args:
        states: List of SGLang state objects from run_batch
        latency: Total latency in seconds
        answer_key: Primary key for answer in state meta info
        additional_answer_keys: Additional keys to include in token count (e.g., ["answer_1", "answer_2"])

    Returns:
        BenchmarkMetrics object with computed metrics
    """
    # Compute output tokens
    num_output_tokens = 0
    if additional_answer_keys:
        for key in [answer_key] + additional_answer_keys:
            num_output_tokens += sum(
                s.get_meta_info(key)["completion_tokens"] for s in states
            )
    else:
        num_output_tokens = sum(
            s.get_meta_info(answer_key)["completion_tokens"] for s in states
        )

    output_throughput = num_output_tokens / latency if latency > 0 else 0.0

    # Compute accept length (speculative decoding metric)
    has_verify = "spec_verify_ct" in states[0].get_meta_info(answer_key)
    if has_verify:
        num_verify_tokens = 0
        if additional_answer_keys:
            for key in [answer_key] + additional_answer_keys:
                num_verify_tokens += sum(
                    s.get_meta_info(key).get("spec_verify_ct", 0) for s in states
                )
        else:
            num_verify_tokens = sum(
                s.get_meta_info(answer_key).get("spec_verify_ct", 0) for s in states
            )

        if num_verify_tokens == 0:
            accept_length = 1.0
        else:
            accept_length = num_output_tokens / num_verify_tokens
    else:
        accept_length = 1.0

    return BenchmarkMetrics(
        latency=latency,
        output_throughput=output_throughput,
        accept_length=accept_length,
        num_questions=len(states),
    )


def print_results(
    metrics_list: List[BenchmarkMetrics],
    benchmark_name: str,
    show_accuracy: bool = False,
):
    """
    Print benchmark results in a formatted way.

    Args:
        metrics_list: List of BenchmarkMetrics from multiple runs
        benchmark_name: Name of the benchmark
        show_accuracy: Whether to show accuracy metrics
    """
    avg_latency = np.mean([m.latency for m in metrics_list])
    avg_throughput = np.mean([m.output_throughput for m in metrics_list])
    avg_accept_length = np.mean([m.accept_length for m in metrics_list])

    print(f"\n{'='*50}")
    print(f"{benchmark_name} Evaluation Results")
    print(f"{'='*50}")
    print(f"Number of questions: {metrics_list[0].num_questions}")
    if show_accuracy:
        if metrics_list[0].accuracy is not None:
            avg_accuracy = np.mean(
                [m.accuracy for m in metrics_list if m.accuracy is not None]
            )
            print(f"Average Accuracy: {avg_accuracy:.4f} ({avg_accuracy*100:.2f}%)")
        else:
            print("Average Accuracy: None")
    print(f"Average Latency: {avg_latency:.3f} s")
    print(f"Average Output throughput: {avg_throughput:.3f} token/s")
    print(f"Average Accept length: {avg_accept_length:.3f}")
    print(f"{'='*50}\n")


def create_simple_sgl_function(
    function_name: str = "get_answer",
    answer_key: str = "answer",
    system_prompt: Optional[str] = None,
    max_tokens: int = 2048,
    stop: Optional[List[str]] = None,
    user_prefix: Optional[str] = None,
) -> Callable:
    """
    Create a simple SGL function for single-turn Q&A.

    Args:
        function_name: Name of the function
        answer_key: Key for storing the answer
        system_prompt: Optional system prompt
        max_tokens: Maximum tokens to generate
        stop: Optional stop sequences
        user_prefix: Optional string appended to the user message after the question (despite the name, it acts as a suffix)

    Returns:
        SGL function decorated with @sgl.function
    """

    @sgl.function
    def sgl_func(s, question):
        if system_prompt:
            s += sgl.system(system_prompt)
        user_content = question
        if user_prefix:
            user_content = question + user_prefix
        s += sgl.user(user_content)
        gen_kwargs = {"max_tokens": max_tokens}
        if stop:
            gen_kwargs["stop"] = stop
        s += sgl.assistant(sgl.gen(answer_key, **gen_kwargs))

    sgl_func.__name__ = function_name
    return sgl_func


def create_few_shot_sgl_function(
    few_shot_examples: str,
    function_name: str = "few_shot_answer",
    answer_key: str = "answer",
    max_tokens: int = 512,
    stop: Optional[List[str]] = None,
) -> Callable:
    """
    Create an SGL function for few-shot learning.

    Args:
        few_shot_examples: String containing few-shot examples
        function_name: Name of the function
        answer_key: Key for storing the answer
        max_tokens: Maximum tokens to generate
        stop: Optional stop sequences

    Returns:
        SGL function decorated with @sgl.function
    """

    @sgl.function
    def sgl_func(s, question):
        s += few_shot_examples + question
        gen_kwargs = {"max_tokens": max_tokens}
        if stop:
            gen_kwargs["stop"] = stop
        s += sgl.gen(answer_key, **gen_kwargs)

    sgl_func.__name__ = function_name
    return sgl_func


def create_multi_turn_sgl_function(
    function_name: str = "multi_turn_answer",
    system_prompt: Optional[str] = None,
    num_turns: int = 2,
    max_tokens: int = 2048,
) -> Callable:
    """
    Create an SGL function for multi-turn conversations (e.g., MT-Bench with 2 turns).

    Args:
        function_name: Name of the function
        system_prompt: Optional system prompt
        num_turns: Number of conversation turns (default: 2)
        max_tokens: Maximum tokens to generate per turn

    Returns:
        SGL function decorated with @sgl.function
    """
    if num_turns == 2:
        # Most common case: 2-turn conversation
        @sgl.function
        def sgl_func(s, question_1, question_2):
            if system_prompt:
                s += sgl.system(system_prompt)
            s += sgl.user(question_1)
            s += sgl.assistant(sgl.gen("answer_1", max_tokens=max_tokens))
            s += sgl.user(question_2)
            s += sgl.assistant(sgl.gen("answer_2", max_tokens=max_tokens))

    else:
        # Generic case: create function with dynamic number of turns
        # Note: This requires the caller to pass arguments as a dict
        @sgl.function
        def sgl_func(s, **kwargs):
            if system_prompt:
                s += sgl.system(system_prompt)
            for i in range(num_turns):
                question_key = f"question_{i+1}"
                answer_key = f"answer_{i+1}"
                if question_key in kwargs:
                    s += sgl.user(kwargs[question_key])
                    s += sgl.assistant(sgl.gen(answer_key, max_tokens=max_tokens))

    sgl_func.__name__ = function_name
    return sgl_func


def create_image_sgl_function(
    function_name: str = "get_image_answer",
    answer_key: str = "answer",
    max_tokens: int = 2048,
) -> Callable:
    """
    Create an SGL function for image-based Q&A.

    Args:
        function_name: Name of the function
        answer_key: Key for storing the answer
        max_tokens: Maximum tokens to generate

    Returns:
        SGL function decorated with @sgl.function
    """

    @sgl.function
    def sgl_func(s, image_path, question, **kwargs):
        """
        The body of the SGL function: constructs a multimodal conversation flow.

        - First, it inputs an image + text question as 'user'.
        - Then, it generates an answer as 'assistant', binding the response to the specified `answer_key`.

        Note: sgl.image() automatically encodes the image into a format supported by the model for multimodal input.
        """
        # User input: Image + Text question
        s += sgl.user(sgl.image(image_path) + question)
        s += sgl.assistant(sgl.gen(answer_key, max_tokens=max_tokens))

    sgl_func.__name__ = function_name
    return sgl_func
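Tying the helpers together: a factory builds the `@sgl.function`, `run_batch` fans the questions out, and `compute_metrics` reads the per-state token and verify counts. A minimal driver sketch, assuming an SGLang server is already listening on `localhost:30000` (the endpoint and questions are illustrative, not from the suite):

```python
import time

import sglang as sgl

from benchmarks.benchmarker.utils import (
    compute_metrics,
    create_simple_sgl_function,
    print_results,
)

sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

func = create_simple_sgl_function(
    function_name="get_answer", answer_key="answer", max_tokens=256
)

questions = [{"question": "What is 2 + 2?"}, {"question": "Name a prime number."}]
tic = time.perf_counter()
states = func.run_batch(questions, progress_bar=True)  # one state per question
latency = time.perf_counter() - tic

metrics = compute_metrics(states, latency, answer_key="answer")
print_results([metrics], "smoke-test")
```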
progress/SpecForge/cache/compiled_kernels/26/c26l7dxpqbfol7d62sqakxdv4rgyh27yhm4hrctevbkw5t6kekia.py
ADDED
|
@@ -0,0 +1,799 @@
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties

@triton_heuristics.template(
    num_stages=3,
    num_warps=8,
    triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]},
    inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
)
@triton.jit
def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32
    Q = arg_Q
    K = arg_K
    V = arg_V
    LSE = arg_LSE
    DELTA = arg_DELTA
    DO = arg_DO
    DQ = arg_DQ
    DV = arg_DV
    KV_NUM_BLKS = arg_KV_NUM_BLKS
    KV_IDX = arg_KV_IDX
    Q_NUM_BLKS = arg_Q_NUM_BLKS
    Q_IDX = arg_Q_IDX
    FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
    FULL_KV_IDX = arg_FULL_KV_IDX
    FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
    FULL_Q_IDX = arg_FULL_Q_IDX

    # Sub notation for this kernel:
    #
    # Q: Query, K: Key, V: Value
    # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
    # DELTA: Precomputed sum(OUT*DO, axis=-1)
    # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
    # DK: Derivative of Key, is written to via the store_output call due to some limitations with
    # inductor codegen
    # M: Number of queries, N: Number of keys/values
    # QK_HEAD_DIM: The dimension of the query and key embeddings
    # V_HEAD_DIM: The dimension of the value embeddings
    # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
    # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
    # (Modifiable) Performance tuning options
    # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
    # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
    # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
    # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
    #
    # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
    # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
    # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
    # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
    # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
    # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
    # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
    # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
    # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.

    # The below are kernel options that can be applied for certain score_mods,
    # or involve a numerics vs. perf tradeoff
    # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
    # about 20% more numerical error, but slightly faster.

    # Define strides of inputs
    stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1
    stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1
    stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1
    stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1

    stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1
    stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1

    ZQ = 1
    HQ = 32
    HKV = 8
    Q_LEN = ks0
    ZKV = 1
    KV_LEN = ks1

    MATMUL_PRECISION = Q.dtype.element_ty

    pid = tl.program_id(0).to(INDEX_DTYPE)
    NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
    NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)

    off_zq = tl.program_id(1).to(INDEX_DTYPE)  # q batch idx
    off_hkv = tl.program_id(2).to(INDEX_DTYPE)  # kv head idx
    off_zkv = off_zq % ZKV  # kv batch idx

    SPARSE_Z = 1
    SPARSE_HQ = 1

    sparse_idx_z = off_zq % SPARSE_Z

    k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
    v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
    # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
    # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
    dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)

    # offset K, V, DV pointers for batch/kv-head
    K += k_adj
    V += v_adj
    DV += dv_adj

    RCP_LN2 = 1.44269504
    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)

    if pid >= NUM_KV_BLOCKS:
        off_pid = pid - NUM_KV_BLOCKS
        # THIS BLOCK DOES DQ
        SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
        SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
        off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
        start_m2_block = off_pid % NUM_Q_BLOCKS
        off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
        stride_kv_num_blks_h = ks2
        stride_kv_idx_h = ks3*ks4
        stride_kv_idx_m = ks4

        sparse_idx_hq2 = off_hq2 % SPARSE_HQ
        sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2

        sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
        sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m  # noqa: B950

        # Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query heads.
        q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
        do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
        dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
        off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)

        Q2 = Q + q_adj2
        DO2 = DO + do_adj2
        # TODO: This does not work if DQ is not the same layout as Q (for example,
        # if Q is broadcasted)
        DQ2 = DQ + dq_adj2
        LSE2 = LSE + off_chz2
        DELTA2 = DELTA + off_chz2

        # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
        dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)

        start_m2 = start_m2_block * BLOCK_M2
        offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)

        # load Q and do: they stay in SRAM throughout the inner loop.
        q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
        do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)

        if PRESCALE_QK:
            q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)

        if IS_DIVISIBLE:
            Di = tl.load(DELTA2 + offs_m2)
            lse = tl.load(LSE2 + offs_m2)
        else:
            Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
            lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
        lse = tl.where(lse == -float("inf"), 0.0, lse)
        lse = lse[:, None]

        # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # KV_IDX and KV_NUM_BLKS are always contiguous.
        kv_indices = KV_IDX + sparse_kv_idx_offset
        kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE  # first kv block we're loading
        sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)

        offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
        dq = bwd_dq_inner(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
            K, V,
            dq, q, do, Di, lse,
            off_zq, off_hq2, offs_m2, offs_n2,
            stride_kn, stride_kd, stride_vn, stride_vd,
            kv_indices, sparse_kv_num_blocks,
            MATMUL_PRECISION,
            IS_FULL_BLOCKS=False,
        )

        if HAS_FULL_BLOCKS:
            # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
            kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
            kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE  # first kv block we're loading
            sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)

            offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
            dq = bwd_dq_inner(
                arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
                K, V,
                dq, q, do, Di, lse,
                off_zq, off_hq2, offs_m2, offs_n2,
                stride_kn, stride_kd, stride_vn, stride_vd,
                kv_indices, sparse_kv_num_blocks,
                MATMUL_PRECISION,
                IS_FULL_BLOCKS=True,
            )

        # Write back dQ.
        dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
        dq *= SM_SCALE
        if IS_DIVISIBLE and SAFE_HEAD_DIM:
            tl.store(dq_ptrs, dq)
        else:
            tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
    else:
        # THIS BLOCK DOES DK & DV
        SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
        SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)

        pid_mask = pid // SPARSE_KV_MULTIPLE

        stride_q_num_blks_h = 6
        stride_q_idx_h = 36
        stride_q_idx_n = 6

        dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
        dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)

        start_n1 = pid * BLOCK_N1
        offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)

        # load K and V: they stay in SRAM throughout the inner loop.
        k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
        v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)

        if PRESCALE_QK:
            k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)

        for off_g in range(0, GQA_SHARED_HEADS):
            off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g

            # Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query heads.
            q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
            do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
            dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
            off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)

            Q1 = Q + q_adj1
            DO1 = DO + do_adj1
            # TODO: This does not work if DQ is not the same layout as Q (for example,
            # if Q is broadcasted)
            LSE1 = LSE + off_chz1
            DELTA1 = DELTA + off_chz1

            sparse_idx_hq1 = off_hq1 % SPARSE_HQ
            sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1

            sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
            sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n  # noqa: B950

            # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            # Q_IDX and Q_NUM_BLKS are always contiguous.
            q_indices = Q_IDX + sparse_q_idx_offset
            q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE  # first q block we're loading
            sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)

            offs_m1 = q_start + tl.arange(0, BLOCK_M1)
            dk, dv = bwd_dkdv_inner(
                arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
                Q1, DO1, DELTA1, LSE1,
                dk, dv, k, v,
                off_zq, off_hq1, offs_n1, offs_m1,
                stride_qm, stride_qd, stride_dom, stride_dod,
                q_indices, sparse_q_num_blocks,
                MATMUL_PRECISION,
                IS_FULL_BLOCKS=False,
            )

            if HAS_FULL_BLOCKS:
                # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
                # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
                q_indices = FULL_Q_IDX + sparse_q_idx_offset
                q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE  # first q block we're loading
                sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)

                offs_m1 = q_start + tl.arange(0, BLOCK_M1)
                dk, dv = bwd_dkdv_inner(
                    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
                    Q1, DO1, DELTA1, LSE1,
                    dk, dv, k, v,
                    off_zq, off_hq1, offs_n1, offs_m1,
                    stride_qm, stride_qd, stride_dom, stride_dod,
                    q_indices, sparse_q_num_blocks,
                    MATMUL_PRECISION,
                    IS_FULL_BLOCKS=True,
                )

        # Write back dV and dK.
        dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd

        index_n = offs_n1[:, None]
        index_k = offs_k[None, :]
        index_v = offs_v[None, :]

        if IS_DIVISIBLE and SAFE_HEAD_DIM:
            tl.store(dv_ptrs, dv)
        else:
            tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))

        dk *= SM_SCALE

        if SAFE_HEAD_DIM:
            mask = index_n < KV_LEN
        else:
            mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)

        # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
        # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
        tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
        xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1
        tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask)

@triton.jit
def bwd_dq_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
    K, V,  # pointers
    dq, q, do, Di, lse,
    off_z, off_hq, offs_m2, offs_n2,
    stride_kn, stride_kd, stride_vn, stride_vd,
    kv_indices, sparse_kv_num_blocks,
    MATMUL_PRECISION,
    IS_FULL_BLOCKS,
):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
    RCP_LN2: tl.constexpr = 1.44269504
    Q_LEN = ks0
    KV_LEN = ks1

    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)

    kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
    vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
    # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)

    hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))

    for start_n in range(0, hi):
        dq = bwd_dq_block_mn(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
            dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
            off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
            stride_kn, stride_kd, stride_vn, stride_vd,
            kv_indices, sparse_kv_num_blocks,
            MATMUL_PRECISION, RCP_LN2,
            IS_FULL_BLOCKS,
        )

        # Increment pointers.
        offset = get_offset_for_next_block(
            start_n, kv_indices, sparse_kv_num_blocks,
            SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
        )

        kT_ptrs += offset * stride_kn
        vT_ptrs += offset * stride_vn

        offs_n2 += offset

    return dq


@triton.jit
def bwd_dq_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
    dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
    off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
    stride_kn, stride_kd, stride_vn, stride_vd,
    kv_indices, sparse_kv_num_blocks,
    MATMUL_PRECISION, RCP_LN2,
    IS_FULL_BLOCKS,
):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    # NB reversed order since K is transposed
    kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
    qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
    if not PRESCALE_QK:
        qk *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    pre_mod_scores = qk
    n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
    # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim
    # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary
    m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)

    tmp0 = (qk)
    post_mod_scores = tmp0

    if not IS_DIVISIBLE:
        post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        tmp1 = (m)
        tmp2 = tl.full([1], 0, tl.int32)
        tmp3 = tmp1 < tmp2
        tmp4 = (n)
        tmp5 = tmp4 <= tmp1
        tmp6 = tmp3 & tmp5
        tmp7 = tmp1 >= tmp2
        tmp8 = tmp4 < tmp2
        tmp9 = tmp7 & tmp8
        tmp10 = tmp8 == 0
        tmp11 = tmp7 & tmp10
        tmp12 = tmp1 - tmp2
        tmp13 = tl.full([1], 16, tl.int32)
        tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
        tmp15 = tmp4 - tmp2
        tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
        tmp17 = tmp14 == tmp16
        tmp18 = tmp11 & tmp17
        tmp19 = tmp9 | tmp18
        tmp20 = tmp6 | tmp19
        mask_mod_output = tmp20

        # apply mask for partial masked block
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if not PRESCALE_QK:
        post_mod_scores *= RCP_LN2
    p = tl.math.exp2(post_mod_scores - lse)
    # Compute dP and dS.
    # NB reversed order since V is transposed
    vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)

    dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
    ds = p * (dp - Di[:, None])
    # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
    tmp21 = (ds)
    grad_scores = tmp21

    if not IS_DIVISIBLE:
        grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)

    # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
    if WRITE_DQ:
        scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    ds = grad_scores

    if not IS_FULL_BLOCKS:
        # (grads) apply mask for partially unmasked block
        ds = tl.where(mask_mod_output, ds, 0.0)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    ds = ds.to(MATMUL_PRECISION)
    # Compute dQ.
    dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)

    return dq


@triton.jit
def bwd_dkdv_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
    Q, DO, DELTA, LSE,  # pointers
    dk, dv, k, v,
    off_z, off_hq, offs_n1, offs_m1,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_indices, sparse_q_num_blocks,
    MATMUL_PRECISION,
    IS_FULL_BLOCKS,
):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
    RCP_LN2: tl.constexpr = 1.44269504
    Q_LEN = ks0
    KV_LEN = ks1

    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)

    qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
    do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)

    # The minimum is needed to handle the case where we run with a super large
    # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
    hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))

    for start_m in range(0, hi):
        dk, dv = bwd_dkdv_block_mn(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
            dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
            off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
            stride_qm, stride_qd, stride_dom, stride_dod,
            q_indices, sparse_q_num_blocks,
            MATMUL_PRECISION, RCP_LN2,
            IS_FULL_BLOCKS,
        )
        # Increment pointers.
        offset = get_offset_for_next_block(
            start_m, q_indices, sparse_q_num_blocks,
            SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
        )

        qT_ptrs += offset * stride_qm
        do_ptrs += offset * stride_dom
        offs_m1 += offset

    return dk, dv


@triton.jit
def bwd_dkdv_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
    dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
    off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_indices, sparse_q_num_blocks,
    MATMUL_PRECISION, RCP_LN2,
    IS_FULL_BLOCKS,
):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    # NB reversed order since Q is transposed
    qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
    # Load LSE before computing qk to reduce pipeline stall.
    if IS_DIVISIBLE:
        lse = tl.load(LSE + offs_m1)
    else:
        lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
    lse = tl.where(lse == -float("inf"), 0.0, lse)
    qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
    if not PRESCALE_QK:
        qkT *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
    # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim
    # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary
    n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)

    pre_mod_scores = qkT
    tmp22 = (qkT)
    post_mod_scores = tmp22

    if not IS_DIVISIBLE:
        post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        tmp23 = (m)
        tmp24 = tl.full([1], 0, tl.int32)
        tmp25 = tmp23 < tmp24
        tmp26 = (n)
        tmp27 = tmp26 <= tmp23
        tmp28 = tmp25 & tmp27
        tmp29 = tmp23 >= tmp24
        tmp30 = tmp26 < tmp24
        tmp31 = tmp29 & tmp30
        tmp32 = tmp30 == 0
        tmp33 = tmp29 & tmp32
        tmp34 = tmp23 - tmp24
        tmp35 = tl.full([1], 16, tl.int32)
        tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35)
        tmp37 = tmp26 - tmp24
        tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35)
        tmp39 = tmp36 == tmp38
        tmp40 = tmp33 & tmp39
        tmp41 = tmp31 | tmp40
        tmp42 = tmp28 | tmp41
        mask_mod_output = tmp42

        # (grads) apply mask for fully masked block
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if not PRESCALE_QK:
        post_mod_scores *= RCP_LN2
    pT = tl.math.exp2(post_mod_scores - lse[None, :])
    do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
    # Compute dV.
    ppT = pT
    dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
    if IS_DIVISIBLE:
        Di = tl.load(DELTA + offs_m1)
    else:
        Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
    # Compute dP and dS.
    dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
    dsT = pT * (dpT - Di[None, :])
    # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
    tmp43 = (dsT)
    grad_scores = tmp43

    if not IS_DIVISIBLE:
        grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)

    # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
    if not WRITE_DQ:
        idx_b = off_z
        idx_h = off_hq
        idx_m = m
        idx_n = n
        scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    dsT = grad_scores
    if not IS_FULL_BLOCKS:
        # (grads) apply mask for partially unmasked block
        dsT = tl.where(mask_mod_output, dsT, 0.0)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)

    return dk, dv

# Utility triton funcs
@triton.jit
def get_offset_for_next_block(
    loop_iter, col_indices, total_blocks,
    SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr
):
    if BLOCKS_ARE_CONTIGUOUS:
        return BLOCK
    cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
    cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
    next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
    needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
    jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
    offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
    return offset

@triton.jit
def get_bounded_indices(indices, max_len=None):
    return indices % max_len if max_len is not None else indices

@triton.jit
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
    if IS_DIVISIBLE and SAFE_HEAD_DIM:
        return tl.load(block_ptr)
    elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
    elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
    else:
        return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")

@triton.jit
def load_checked_2d(
    ptr,
    offs_m,
    offs_n,
    stride_m,
    stride_n,
    IS_DIVISIBLE_M: tl.constexpr,
    IS_DIVISIBLE_N: tl.constexpr,
    M_LEN: tl.constexpr,
    N_LEN: tl.constexpr,
):
    # Calculate final pointer if strides are provided
    if stride_m is not None and stride_n is not None:
        ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n

    # Handle all masking cases
    if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
    elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
    elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
    else:  # Both divisible
        return tl.load(ptr)
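For orientation (this note is not part of the diff): the file above is torch.compile's cached Triton template for the backward pass of FlexAttention with a block-sparse mask and grouped-query attention. A minimal sketch of the kind of user-level call that compiles into such a kernel; the shapes and the causal mask_mod below are hypothetical stand-ins, inferred only from the constants baked into the template (HQ=32, HKV=8, GQA_SHARED_HEADS=4, head dim 128, SM_SCALE = 1/sqrt(128) = 0.08838834764831845):

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

B, HQ, HKV, SEQ, D = 1, 32, 8, 4096, 128  # hypothetical sizes

def causal(b, h, q_idx, kv_idx):
    # Placeholder mask_mod; the mask actually traced in this repo is whatever
    # the training code passed to flex_attention and is not shown in the diff.
    return q_idx >= kv_idx

q = torch.randn(B, HQ, SEQ, D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(B, HKV, SEQ, D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(B, HKV, SEQ, D, device="cuda", dtype=torch.bfloat16, requires_grad=True)

block_mask = create_block_mask(causal, B, HQ, SEQ, SEQ, device="cuda")
out = torch.compile(flex_attention)(q, k, v, block_mask=block_mask, enable_gqa=True)
out.sum().backward()  # compiling this backward emits a template like the one above

In the template itself, programs with pid below NUM_KV_BLOCKS accumulate dK/dV by iterating over Q tiles, while the remaining programs compute dQ by iterating over KV tiles; the block-mask index tensors (KV_IDX, Q_IDX and their FULL_* variants) let each program skip fully masked tiles and drop the mask check on fully unmasked ones.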
progress/SpecForge/cache/compiled_kernels/2d/c2d4e47kqxxnp6455gvkteqq3r336462zkbitosyeko6znxktn2b.py
ADDED
@@ -0,0 +1,879 @@
| 1 |
+
# AOT ID: ['3_inference']
|
| 2 |
+
from ctypes import c_void_p, c_long, c_int
|
| 3 |
+
import torch
|
| 4 |
+
import math
|
| 5 |
+
import random
|
| 6 |
+
import os
|
| 7 |
+
import tempfile
|
| 8 |
+
from math import inf, nan
|
| 9 |
+
from cmath import nanj
|
| 10 |
+
from torch._inductor.hooks import run_intermediate_hooks
|
| 11 |
+
from torch._inductor.utils import maybe_profile
|
| 12 |
+
from torch._inductor.codegen.memory_planning import _align as align
|
| 13 |
+
from torch import device, empty_strided
|
| 14 |
+
from torch._inductor.async_compile import AsyncCompile
|
| 15 |
+
from torch._inductor.select_algorithm import extern_kernels
|
| 16 |
+
import triton
|
| 17 |
+
import triton.language as tl
|
| 18 |
+
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
|
| 19 |
+
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
|
| 20 |
+
|
| 21 |
+
aten = torch.ops.aten
|
| 22 |
+
inductor_ops = torch.ops.inductor
|
| 23 |
+
_quantized = torch.ops._quantized
|
| 24 |
+
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
|
| 25 |
+
assert_alignment = torch._C._dynamo.guards.assert_alignment
|
| 26 |
+
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
|
| 27 |
+
empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
|
| 28 |
+
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
|
| 29 |
+
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
|
| 30 |
+
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
|
| 31 |
+
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
|
| 32 |
+
alloc_from_pool = torch.ops.inductor._alloc_from_pool
|
| 33 |
+
async_compile = AsyncCompile()
|
| 34 |
+
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7g/c7gxkvfztxetv7w7i4s7mr7dlsdda3dfgq3f3uijvhozq6ggk4o4.py
|
| 38 |
+
# Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
|
| 39 |
+
# Source node to ATen node mapping:
|
| 40 |
+
# flex_attention => flex_attention
|
| 41 |
+
# Graph fragment:
|
| 42 |
+
# %arg1_1 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:3" = PlaceHolder[target=arg1_1]
|
| 43 |
+
# %arg3_1 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:3" = PlaceHolder[target=arg3_1]
|
| 44 |
+
# %arg5_1 : Tensor "bf16[1, 8, s43, 128][1024*s43, 128, 1024, 1]cuda:3" = PlaceHolder[target=arg5_1]
|
| 45 |
+
# %buf0 : Tensor "f32[1, 32, 32, s37][1024*s37, 32*s37, s37, 1]cuda:3" = PlaceHolder[target=buf0]
|
| 46 |
+
# %buf1 : Tensor "f32[1, 32, 32, s37][1024*s37, 32*s37, s37, 1]cuda:3" = PlaceHolder[target=buf1]
|
| 47 |
+
# %arg9_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:3" = PlaceHolder[target=arg9_1]
|
| 48 |
+
# %arg6_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:3" = PlaceHolder[target=arg6_1]
|
| 49 |
+
# %arg10_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:3" = PlaceHolder[target=arg10_1]
|
| 50 |
+
# %arg11_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:3" = PlaceHolder[target=arg11_1]
|
| 51 |
+
# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
|
| 52 |
+
# return %buf2
|
| 53 |
+
triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', '''
|
| 54 |
+
import triton
|
| 55 |
+
import triton.language as tl
|
| 56 |
+
|
| 57 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 58 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 59 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 60 |
+
|
| 61 |
+
@triton_heuristics.template(
|
| 62 |
+
|
| 63 |
+
num_stages=3,
|
| 64 |
+
num_warps=2,
|
| 65 |
+
triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]},
|
| 66 |
+
inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}},
|
| 67 |
+
|
| 68 |
+
)
|
| 69 |
+
@triton.jit
|
| 70 |
+
def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2):
|
| 71 |
+
PRESCALE_QK : tl.constexpr = False
|
| 72 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 73 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 74 |
+
WRITE_DQ : tl.constexpr = True
|
| 75 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 76 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 77 |
+
FLOAT32_PRECISION : tl.constexpr = 'ieee'
|
| 78 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 79 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 80 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 81 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 82 |
+
SPLIT_KV : tl.constexpr = 32
|
| 83 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 84 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 85 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 86 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 87 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 88 |
+
BLOCK_M : tl.constexpr = 512
|
| 89 |
+
SAFE_M_BOUNDARY : tl.constexpr = False
|
| 90 |
+
SAFE_N_BOUNDARY : tl.constexpr = True
|
| 91 |
+
BLOCK_N : tl.constexpr = 64
|
| 92 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 93 |
+
USE_TMA : tl.constexpr = False
|
| 94 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 95 |
+
Q = arg_Q
|
| 96 |
+
K = arg_K
|
| 97 |
+
V = arg_V
|
| 98 |
+
M = arg_M
|
| 99 |
+
L = arg_L
|
| 100 |
+
KV_NUM_BLKS = arg_KV_NUM_BLKS
|
| 101 |
+
KV_IDX = arg_KV_IDX
|
| 102 |
+
FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
|
| 103 |
+
FULL_KV_IDX = arg_FULL_KV_IDX
|
| 104 |
+
|
| 105 |
+
# Sub notation for this kernel:
|
| 106 |
+
# Q: Query, K: Key, V: Value
|
| 107 |
+
# reduction buffers: M = rowmax across the local KV split, L = local sumexp across the local KV split
|
| 108 |
+
# M: Number of queries, N: Number of keys/values
|
| 109 |
+
# QK_HEAD_DIM: The dimension of the query and key embeddings
|
| 110 |
+
# V_HEAD_DIM: The dimension of the value embeddings
|
| 111 |
+
# BLOCK_M, QK_HEAD_DIM: the M and D dimensions are always assigned to the same block
|
| 112 |
+
# z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head, t: Number of kv splits
|
| 113 |
+
# (Modifiable) Config options:
|
| 114 |
+
# SPLIT_KV: number of blocks K & V are split into
|
| 115 |
+
# TILE_KV: length of each local KV split
|
| 116 |
+
# BLOCK_M: block size that Q is padded to along the seqlen dim.
|
| 117 |
+
# BLOCK_N: block size of K & V along N dimension.
|
| 118 |
+
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
|
| 119 |
+
#
|
| 120 |
+
|
| 121 |
+
# ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
|
| 122 |
+
# is not masked out? If so, we can skip an extra safety check
|
| 123 |
+
# SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query.
|
| 124 |
+
# SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value.
|
| 125 |
+
|
| 126 |
+
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and move the change of base out of the loop.
|
| 127 |
+
#
|
| 128 |
+
# SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim.
|
| 129 |
+
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
|
| 130 |
+
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
|
| 131 |
+
#
|
| 132 |
+
#
|
| 133 |
+
# Output: ACC output accumulated across local KV split.
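#
# A sketch of the cross-split combine that the later reduction kernels
# perform on these buffers (m_a/l_a/acc_a, m_b/l_b/acc_b are illustrative
# names for two splits' partial results, in this kernel's base-2 math):
#   m   = max(m_a, m_b)
#   l   = l_a * 2**(m_a - m) + l_b * 2**(m_b - m)
#   acc = acc_a * 2**(m_a - m) + acc_b * 2**(m_b - m)
#   out = acc / l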
|
| 134 |
+
|
| 135 |
+
tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
|
| 136 |
+
|
| 137 |
+
# Define Q Strides
|
| 138 |
+
stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1
|
| 139 |
+
stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1
|
| 140 |
+
stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1
|
| 141 |
+
stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1
|
| 142 |
+
stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
Z = 1
|
| 146 |
+
ZKV = 1
|
| 147 |
+
HKV = 8
|
| 148 |
+
G: tl.constexpr = GQA_SHARED_HEADS
|
| 149 |
+
HQ = HKV * G
|
| 150 |
+
Q_LEN = ks0
|
| 151 |
+
KV_LEN = ks1
|
| 152 |
+
|
| 153 |
+
MATMUL_PRECISION = Q.dtype.element_ty
|
| 154 |
+
|
| 155 |
+
# Make sure each split is a multiple of BLOCK_N
|
| 156 |
+
TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV)
|
| 157 |
+
TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N
|
| 158 |
+
TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N)
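# Worked example: with KV_LEN = 4096 and SPLIT_KV = 32, TILE_KV_OG = 128,
# TILE_KV = 128 and TILE_KV_MULTIPLE = 2 (BLOCK_N = 64), so each CTA walks
# two BLOCK_N-sized tiles of K/V.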
|
| 159 |
+
|
| 160 |
+
off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV
|
| 161 |
+
off_zkv = off_z % ZKV
|
| 162 |
+
off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV
|
| 163 |
+
off_t = tl.program_id(1).to(INDEX_DTYPE)
|
| 164 |
+
|
| 165 |
+
q_offset = off_z * stride_qz + off_hkv * stride_qh
|
| 166 |
+
k_offset = off_zkv * stride_kz + off_hkv * stride_kh
|
| 167 |
+
v_offset = off_zkv * stride_vz + off_hkv * stride_vh
|
| 168 |
+
|
| 169 |
+
K = K + k_offset
|
| 170 |
+
V = V + v_offset
|
| 171 |
+
|
| 172 |
+
SPARSE_Z = 1
|
| 173 |
+
SPARSE_HQ = 1
|
| 174 |
+
|
| 175 |
+
sparse_idx_z = off_z % SPARSE_Z
|
| 176 |
+
sparse_idx_h = off_hkv % SPARSE_HQ
|
| 177 |
+
|
| 178 |
+
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
|
| 179 |
+
SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE)
|
| 180 |
+
|
| 181 |
+
# initialize pointer to m and l
|
| 182 |
+
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
| 183 |
+
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
| 184 |
+
acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
| 185 |
+
|
| 186 |
+
# initialize offsets
|
| 187 |
+
tl.device_assert(BLOCK_M % G == 0)
|
| 188 |
+
BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G
|
| 189 |
+
off_g = tl.arange(0, G) # [G]
|
| 190 |
+
offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M]
|
| 191 |
+
offs_hq = offs_g + off_hkv * G
|
| 192 |
+
off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ]
|
| 193 |
+
offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M]
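# Packing example: with G = 4 and BLOCK_M = 512, BLOCK_M_PER_HQ = 128, so
# rows 0..127 of the tile hold query head off_hkv*4 + 0, rows 128..255 hold
# head off_hkv*4 + 1, and so on; offs_m repeats 0..127 once per group.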
|
| 194 |
+
offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 195 |
+
offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
| 196 |
+
|
| 197 |
+
# Get HZ offsets for KV_NUM_BLKS and KV_IDX
|
| 198 |
+
stride_block_z, stride_block_h, stride_block_row = 1, 1, 1
|
| 199 |
+
sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h
|
| 200 |
+
stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1
|
| 201 |
+
sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h
|
| 202 |
+
|
| 203 |
+
# Calculate KV blocks that belong to this CTA.
|
| 204 |
+
block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block
|
| 205 |
+
block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N
|
| 206 |
+
|
| 207 |
+
q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :]
|
| 208 |
+
|
| 209 |
+
if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM:
|
| 210 |
+
q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN))
|
| 211 |
+
elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM:
|
| 212 |
+
q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM)
|
| 213 |
+
elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM:
|
| 214 |
+
q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN)
|
| 215 |
+
else:
|
| 216 |
+
q = tl.load(Q + q_offset + q_range)
|
| 217 |
+
|
| 218 |
+
q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED])
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
# ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 222 |
+
# find first kv block we are loading and the number of blocks we are loading
|
| 223 |
+
# Offset the kv_indices tensor by the correct batch and head
|
| 224 |
+
kv_indices = KV_IDX + sparse_idx_hz_offset
|
| 225 |
+
kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset)
|
| 226 |
+
MAX_KV_IDX = 1
|
| 227 |
+
indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX)
|
| 228 |
+
off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
|
| 229 |
+
off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
|
| 230 |
+
# first kv block we're loading
|
| 231 |
+
|
| 232 |
+
# last valid block according to sparse mask
|
| 233 |
+
block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
|
| 234 |
+
|
| 235 |
+
offs_n = tl.arange(0, BLOCK_N) + off_n
|
| 236 |
+
|
| 237 |
+
desc_k = None
|
| 238 |
+
desc_v = None
|
| 239 |
+
|
| 240 |
+
acc, l_i, m_i = forward_inner(
|
| 241 |
+
arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
|
| 242 |
+
q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
|
| 243 |
+
# accumulated values
|
| 244 |
+
acc, l_i, m_i,
|
| 245 |
+
#offsets
|
| 246 |
+
off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
|
| 247 |
+
off_n,
|
| 248 |
+
#block sparse data
|
| 249 |
+
kv_indices, kv_num_blocks,
|
| 250 |
+
block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
|
| 251 |
+
MATMUL_PRECISION,
|
| 252 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 253 |
+
IS_FULL_BLOCKS=False,
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
# ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 258 |
+
# We know these blocks are guaranteed to be "full", so we don't need to
|
| 259 |
+
# apply mask_mod to them - only score_mod
|
| 260 |
+
if HAS_FULL_BLOCKS:
|
| 261 |
+
kv_indices = FULL_KV_IDX + sparse_idx_hz_offset
|
| 262 |
+
kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset)
|
| 263 |
+
# Assign full blocks in reverse order of off_t, prioritizing the last CTA.
|
| 264 |
+
block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE
|
| 265 |
+
block_n_end = block_n_start + TILE_KV_MULTIPLE
|
| 266 |
+
indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX)
|
| 267 |
+
off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
|
| 268 |
+
off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
|
| 269 |
+
|
| 270 |
+
# last valid block according to sparse mask
|
| 271 |
+
block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
|
| 272 |
+
|
| 273 |
+
offs_n = tl.arange(0, BLOCK_N) + off_n
|
| 274 |
+
|
| 275 |
+
acc, l_i, m_i = forward_inner(
|
| 276 |
+
arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
|
| 277 |
+
q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
|
| 278 |
+
# accumulated values
|
| 279 |
+
acc, l_i, m_i,
|
| 280 |
+
#offsets
|
| 281 |
+
off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
|
| 282 |
+
off_n,
|
| 283 |
+
#block sparse data
|
| 284 |
+
kv_indices, kv_num_blocks,
|
| 285 |
+
block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
|
| 286 |
+
MATMUL_PRECISION,
|
| 287 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 288 |
+
IS_FULL_BLOCKS=True,
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
m_offset = off_t * stride_mt + off_z * stride_mz
|
| 292 |
+
l_offset = off_t * stride_lt + off_z * stride_lz
|
| 293 |
+
|
| 294 |
+
M_block_ptr = tl.make_block_ptr(
|
| 295 |
+
base=M + m_offset,
|
| 296 |
+
shape=(G, Q_LEN), # (G, M)
|
| 297 |
+
strides=(stride_mh, stride_mm),
|
| 298 |
+
offsets=(off_hkv*G, 0),
|
| 299 |
+
block_shape=(G, BLOCK_M_PER_HQ),
|
| 300 |
+
order=(1, 0)
|
| 301 |
+
)
|
| 302 |
+
L_block_ptr = tl.make_block_ptr(
|
| 303 |
+
base=L + l_offset,
|
| 304 |
+
shape=(G, Q_LEN), # (G, M)
|
| 305 |
+
strides=(stride_lh, stride_lm),
|
| 306 |
+
offsets=(off_hkv*G, 0),
|
| 307 |
+
block_shape=(G, BLOCK_M_PER_HQ),
|
| 308 |
+
order=(1, 0)
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
# Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16)
|
| 312 |
+
m_i = m_i.reshape(G, BLOCK_M_PER_HQ)
|
| 313 |
+
l_i = l_i.reshape(G, BLOCK_M_PER_HQ)
|
| 314 |
+
if SAFE_M_BOUNDARY:
|
| 315 |
+
tl.store(M_block_ptr, m_i)
|
| 316 |
+
tl.store(L_block_ptr, l_i)
|
| 317 |
+
else:
|
| 318 |
+
tl.store(M_block_ptr, m_i, boundary_check=(1,))
|
| 319 |
+
tl.store(L_block_ptr, l_i, boundary_check=(1,))
|
| 320 |
+
|
| 321 |
+
# -- store output
|
| 322 |
+
idx_z = off_z
|
| 323 |
+
idx_t = off_t
|
| 324 |
+
idx_hq = off_hkv*G + off_g[:, None, None]
|
| 325 |
+
idx_m = off_m[None, :, None]
|
| 326 |
+
idx_d = offs_vd[None, None, :]
|
| 327 |
+
|
| 328 |
+
mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
|
| 329 |
+
acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
|
| 330 |
+
xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0
|
| 331 |
+
tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask)
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
# Utility triton funcs
|
| 335 |
+
@triton.jit
|
| 336 |
+
def get_offset_for_next_block(
|
| 337 |
+
loop_iter, col_indices, total_blocks,
|
| 338 |
+
SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
|
| 339 |
+
BLOCKS_ARE_CONTIGUOUS: tl.constexpr
|
| 340 |
+
):
|
| 341 |
+
if BLOCKS_ARE_CONTIGUOUS:
|
| 342 |
+
return BLOCK
|
| 343 |
+
cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
|
| 344 |
+
cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
|
| 345 |
+
next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
|
| 346 |
+
needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
|
| 347 |
+
jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
|
| 348 |
+
offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
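# Branchless select: needs_jump is 0 or 1, so offset is BLOCK while inside
# a sparse block and jump_to_block when crossing into the next (possibly
# non-adjacent) sparse block.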
|
| 349 |
+
return offset
|
| 350 |
+
|
| 351 |
+
@triton.jit
|
| 352 |
+
def get_bounded_indices(indices, max_len=None):
|
| 353 |
+
return indices % max_len if max_len is not None else indices
|
| 354 |
+
|
| 355 |
+
@triton.jit
|
| 356 |
+
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
|
| 357 |
+
if IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 358 |
+
return tl.load(block_ptr)
|
| 359 |
+
elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
|
| 360 |
+
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
|
| 361 |
+
elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 362 |
+
return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
|
| 363 |
+
else:
|
| 364 |
+
return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
|
| 365 |
+
|
| 366 |
+
@triton.jit
|
| 367 |
+
def load_checked_2d(
|
| 368 |
+
ptr,
|
| 369 |
+
offs_m,
|
| 370 |
+
offs_n,
|
| 371 |
+
stride_m,
|
| 372 |
+
stride_n,
|
| 373 |
+
IS_DIVISIBLE_M: tl.constexpr,
|
| 374 |
+
IS_DIVISIBLE_N: tl.constexpr,
|
| 375 |
+
M_LEN: tl.constexpr,
|
| 376 |
+
N_LEN: tl.constexpr,
|
| 377 |
+
):
|
| 378 |
+
# Calculate final pointer if strides are provided
|
| 379 |
+
if stride_m is not None and stride_n is not None:
|
| 380 |
+
ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
|
| 381 |
+
|
| 382 |
+
# Handle all masking cases
|
| 383 |
+
if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
|
| 384 |
+
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
|
| 385 |
+
elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
|
| 386 |
+
return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
|
| 387 |
+
elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
|
| 388 |
+
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
|
| 389 |
+
else: # Both divisible
|
| 390 |
+
return tl.load(ptr)
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
# Common Imports
|
| 394 |
+
@triton.jit
|
| 395 |
+
def forward_block_mn(
|
| 396 |
+
arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
|
| 397 |
+
q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
|
| 398 |
+
# accumulated values
|
| 399 |
+
acc, l_i, m_i,
|
| 400 |
+
# Offsets
|
| 401 |
+
off_z, off_h, offs_m, offs_n,
|
| 402 |
+
# Offsets needed for TMA loads
|
| 403 |
+
kv_start,
|
| 404 |
+
kv_offset,
|
| 405 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 406 |
+
# Strides for K and V
|
| 407 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 408 |
+
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
|
| 409 |
+
|
| 410 |
+
):
|
| 411 |
+
# Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
|
| 412 |
+
PRESCALE_QK : tl.constexpr = False
|
| 413 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 414 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 415 |
+
WRITE_DQ : tl.constexpr = True
|
| 416 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 417 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 418 |
+
FLOAT32_PRECISION : tl.constexpr = 'ieee'
|
| 419 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 420 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 421 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 422 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 423 |
+
SPLIT_KV : tl.constexpr = 32
|
| 424 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 425 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 426 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 427 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 428 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 429 |
+
BLOCK_M : tl.constexpr = 512
|
| 430 |
+
SAFE_M_BOUNDARY : tl.constexpr = False
|
| 431 |
+
SAFE_N_BOUNDARY : tl.constexpr = True
|
| 432 |
+
BLOCK_N : tl.constexpr = 64
|
| 433 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 434 |
+
USE_TMA : tl.constexpr = False
|
| 435 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
# -- load k --
|
| 439 |
+
# NB: reversed order since K is transposed
|
| 440 |
+
kv_base_offset = kv_start + kv_offset
|
| 441 |
+
|
| 442 |
+
# Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
|
| 443 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 444 |
+
offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
|
| 445 |
+
k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
|
| 446 |
+
|
| 447 |
+
k = tl.trans(k)
|
| 448 |
+
# -- compute qk ---
|
| 449 |
+
qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
|
| 450 |
+
if not PRESCALE_QK:
|
| 451 |
+
qk *= SM_SCALE
|
| 452 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
|
| 453 |
+
# If this is the last block of a non-divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
|
| 454 |
+
# which is larger than the actual number of elements. To avoid out-of-bounds memory access,
|
| 455 |
+
# we need to mask out the elements that are beyond Q_LEN & KV_LEN.
|
| 456 |
+
m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
|
| 457 |
+
n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
|
| 458 |
+
|
| 459 |
+
tmp0 = (qk)
|
| 460 |
+
post_mod_scores = tmp0
|
| 461 |
+
|
| 462 |
+
|
| 463 |
+
if CHECK_BLOCK_BOUNDARY:
|
| 464 |
+
# Mask out the elements that are beyond KV_LEN for a non-divisible seqlen.
|
| 465 |
+
post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
|
| 466 |
+
|
| 467 |
+
if not IS_FULL_BLOCKS:
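# Reading off the generated tmp* ops below, the inlined mask_mod keeps a
# (query m, key n) pair when: m < 0 and n <= m, or m >= 0 and n < 0, or
# both are non-negative and share the same 16-wide chunk (m // 16 == n // 16).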
|
| 468 |
+
tmp1 = (m)
|
| 469 |
+
tmp2 = tl.full([1], 0, tl.int32)
|
| 470 |
+
tmp3 = tmp1 < tmp2
|
| 471 |
+
tmp4 = (n)
|
| 472 |
+
tmp5 = tmp4 <= tmp1
|
| 473 |
+
tmp6 = tmp3 & tmp5
|
| 474 |
+
tmp7 = tmp1 >= tmp2
|
| 475 |
+
tmp8 = tmp4 < tmp2
|
| 476 |
+
tmp9 = tmp7 & tmp8
|
| 477 |
+
tmp10 = tmp8 == 0
|
| 478 |
+
tmp11 = tmp7 & tmp10
|
| 479 |
+
tmp12 = tmp1 - tmp2
|
| 480 |
+
tmp13 = tl.full([1], 16, tl.int32)
|
| 481 |
+
tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
|
| 482 |
+
tmp15 = tmp4 - tmp2
|
| 483 |
+
tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
|
| 484 |
+
tmp17 = tmp14 == tmp16
|
| 485 |
+
tmp18 = tmp11 & tmp17
|
| 486 |
+
tmp19 = tmp9 | tmp18
|
| 487 |
+
tmp20 = tmp6 | tmp19
|
| 488 |
+
mask_mod_output = tmp20
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
if CHECK_BLOCK_BOUNDARY:
|
| 492 |
+
mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
|
| 493 |
+
# apply mask for partially unmasked blocks
|
| 494 |
+
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
|
| 495 |
+
|
| 496 |
+
if not PRESCALE_QK:
|
| 497 |
+
post_mod_scores *= RCP_LN2
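# RCP_LN2 = log2(e), so exp(x) becomes exp2(x * log2(e)); the exp2/log2
# pairs used below compute the same softmax while mapping to cheaper
# base-2 hardware instructions than exp/log.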
|
| 498 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 499 |
+
|
| 500 |
+
# -- compute scaling constant ---
|
| 501 |
+
m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
|
| 502 |
+
if not ROWS_GUARANTEED_SAFE:
|
| 503 |
+
masked_out_rows = (m_ij == float("-inf"))
|
| 504 |
+
m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
|
| 505 |
+
else:
|
| 506 |
+
m_ij_masked = m_ij
|
| 507 |
+
|
| 508 |
+
alpha = tl.math.exp2(m_i - m_ij_masked)
|
| 509 |
+
p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
|
| 510 |
+
|
| 511 |
+
# NB: l_i update is pulled up here since it's a bit faster
|
| 512 |
+
# NB: For headdim=256, it's faster to move it back down to after m_i =
|
| 513 |
+
# m_ij
|
| 514 |
+
l_i = l_i * alpha + tl.sum(p, 1)
|
| 515 |
+
# # -- scale and update acc --
|
| 516 |
+
acc = acc * alpha[:, None]
|
| 517 |
+
# Calculate offsets for V loading - reuse kv_base_offset from K loading
|
| 518 |
+
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
| 519 |
+
v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
|
| 520 |
+
acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
|
| 521 |
+
|
| 522 |
+
# -- update m_i
|
| 523 |
+
m_i = m_ij
|
| 524 |
+
|
| 525 |
+
return acc, l_i, m_i
|
| 526 |
+
|
| 527 |
+
@triton.jit
|
| 528 |
+
def forward_inner(
|
| 529 |
+
arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
|
| 530 |
+
q, K, V,
|
| 531 |
+
desc_k, desc_v, Q_LEN, KV_LEN,
|
| 532 |
+
# accumulated values
|
| 533 |
+
acc, l_i, m_i,
|
| 534 |
+
# Offsets used as inputs to score_mod & mask_mod
|
| 535 |
+
# of size [BLOCK_M, BLOCK_N] or scalar.
|
| 536 |
+
off_z, off_h, offs_m, offs_n,
|
| 537 |
+
# Offsets needed for TMA loads
|
| 538 |
+
kv_start,
|
| 539 |
+
# blocksparse data
|
| 540 |
+
kv_indices, kv_num_blocks,
|
| 541 |
+
# start kv and end kv block
|
| 542 |
+
block_n_start, block_n_end,
|
| 543 |
+
MATMUL_PRECISION,
|
| 544 |
+
# Strides for K and V
|
| 545 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 546 |
+
IS_FULL_BLOCKS,
|
| 547 |
+
):
|
| 548 |
+
# Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
|
| 549 |
+
PRESCALE_QK : tl.constexpr = False
|
| 550 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 551 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 552 |
+
WRITE_DQ : tl.constexpr = True
|
| 553 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 554 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 555 |
+
FLOAT32_PRECISION : tl.constexpr = 'ieee'
|
| 556 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 557 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 558 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 559 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 560 |
+
SPLIT_KV : tl.constexpr = 32
|
| 561 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 562 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 563 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 564 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 565 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 566 |
+
BLOCK_M : tl.constexpr = 512
|
| 567 |
+
SAFE_M_BOUNDARY : tl.constexpr = False
|
| 568 |
+
SAFE_N_BOUNDARY : tl.constexpr = True
|
| 569 |
+
BLOCK_N : tl.constexpr = 64
|
| 570 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 571 |
+
USE_TMA : tl.constexpr = False
|
| 572 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
|
| 576 |
+
RCP_LN2: tl.constexpr = 1.44269504
|
| 577 |
+
|
| 578 |
+
if PRESCALE_QK:
|
| 579 |
+
q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
|
| 580 |
+
|
| 581 |
+
kv_offset = 0
|
| 582 |
+
|
| 583 |
+
# loop over k, v and update accumulator until block_n_end
|
| 584 |
+
for start_n in range(block_n_start, block_n_end):
|
| 585 |
+
# Here IS_DIVISIBLE acts as the start_n = tl.multiple_of(start_n, BLOCK_N) hint from triton_fused_attention.
|
| 586 |
+
if IS_DIVISIBLE:
|
| 587 |
+
acc, l_i, m_i = forward_block_mn(
|
| 588 |
+
arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
|
| 589 |
+
q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
|
| 590 |
+
# accumulated values
|
| 591 |
+
acc, l_i, m_i,
|
| 592 |
+
# Offsets
|
| 593 |
+
off_z, off_h, offs_m, offs_n,
|
| 594 |
+
# Offsets needed for TMA loads
|
| 595 |
+
kv_start,
|
| 596 |
+
kv_offset,
|
| 597 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 598 |
+
# Strides for K and V
|
| 599 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 600 |
+
IS_FULL_BLOCKS,
|
| 601 |
+
)
|
| 602 |
+
else:
|
| 603 |
+
# Benchmarks show that even when we apply mod & mask to every block for a
|
| 604 |
+
# non-divisible seqlen, it is on par with or slightly faster than applying
|
| 605 |
+
# them only to the last block in fwd. However, we choose a different strategy
|
| 606 |
+
# for bwd, applying mod & mask only to the last block, which is much faster.
|
| 607 |
+
acc, l_i, m_i = forward_block_mn(
|
| 608 |
+
arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
|
| 609 |
+
q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
|
| 610 |
+
# accumulated values
|
| 611 |
+
acc, l_i, m_i,
|
| 612 |
+
# Offsets
|
| 613 |
+
off_z, off_h, offs_m, offs_n,
|
| 614 |
+
# Offsets needed for TMA loads
|
| 615 |
+
kv_start,
|
| 616 |
+
kv_offset,
|
| 617 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 618 |
+
# Strides for K and V
|
| 619 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 620 |
+
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
|
| 621 |
+
)
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
offset = get_offset_for_next_block(
|
| 626 |
+
start_n, kv_indices, kv_num_blocks,
|
| 627 |
+
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
|
| 628 |
+
)
|
| 629 |
+
|
| 630 |
+
offs_n = offs_n + offset
|
| 631 |
+
kv_offset += offset
|
| 632 |
+
|
| 633 |
+
|
| 634 |
+
return acc, l_i, m_i
|
| 635 |
+
''', device_str='cuda')
|
| 636 |
+
|
| 637 |
+
|
| 638 |
+
# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/6g/c6gb52skvqs7or57vd3zu5um3r5rnmeimd5qam27l5j7uqx7t4ai.py
|
| 639 |
+
# Topologically Sorted Source Nodes: [flex_attention, lse_scaled], Original ATen: [aten.mul]
|
| 640 |
+
# Source node to ATen node mapping:
|
| 641 |
+
# flex_attention => flex_attention
|
| 642 |
+
# lse_scaled => mul_9
|
| 643 |
+
# Graph fragment:
|
| 644 |
+
# %buf3 : Tensor = PlaceHolder[target=buf3]
|
| 645 |
+
# %buf4 : Tensor = PlaceHolder[target=buf4]
|
| 646 |
+
# %buf5 : Tensor "f32[1, 1, 32, s37][32*s37, 32*s37, s37, 1]cuda:3" = PlaceHolder[target=buf5]
|
| 647 |
+
# %buf7 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:3" = PlaceHolder[target=buf7]
|
| 648 |
+
# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
|
| 649 |
+
# %mul_9 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {})
|
| 650 |
+
# return %buf5,%buf7,%mul_9
|
| 651 |
+
triton_per_fused_mul_1 = async_compile.triton('triton_per_fused_mul_1', '''
|
| 652 |
+
import triton
|
| 653 |
+
import triton.language as tl
|
| 654 |
+
|
| 655 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 656 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 657 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 658 |
+
triton_helpers.set_driver_to_gpu()
|
| 659 |
+
|
| 660 |
+
@triton_heuristics.persistent_reduction(
|
| 661 |
+
size_hints={'x': 4096, 'r0_': 32},
|
| 662 |
+
reduction_hint=ReductionHint.DEFAULT,
|
| 663 |
+
filename=__file__,
|
| 664 |
+
triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
|
| 665 |
+
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
|
| 666 |
+
)
|
| 667 |
+
@triton.jit
|
| 668 |
+
def triton_per_fused_mul_1(in_ptr0, in_ptr1, out_ptr0, out_ptr1, out_ptr2, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr):
|
| 669 |
+
r0_numel = 32
|
| 670 |
+
R0_BLOCK: tl.constexpr = 32
|
| 671 |
+
rnumel = r0_numel
|
| 672 |
+
RBLOCK: tl.constexpr = R0_BLOCK
|
| 673 |
+
xoffset = tl.program_id(0) * XBLOCK
|
| 674 |
+
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
|
| 675 |
+
xmask = xindex < xnumel
|
| 676 |
+
r0_index = tl.arange(0, R0_BLOCK)[None, :]
|
| 677 |
+
r0_offset = 0
|
| 678 |
+
r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
|
| 679 |
+
roffset = r0_offset
|
| 680 |
+
rindex = r0_index
|
| 681 |
+
r0_1 = r0_index
|
| 682 |
+
x0 = xindex
|
| 683 |
+
x2 = (xindex % ks0)
|
| 684 |
+
x3 = triton_helpers.div_floor_integer(xindex, ks0)
|
| 685 |
+
tmp0 = tl.load(in_ptr0 + (x0 + 32*ks0*r0_1), xmask, other=0.0)
|
| 686 |
+
tmp5 = tl.load(in_ptr1 + (x0 + 32*ks0*r0_1), xmask, other=0.0)
|
| 687 |
+
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
|
| 688 |
+
tmp3 = tl.where(xmask, tmp1, float("-inf"))
|
| 689 |
+
tmp4 = triton_helpers.max2(tmp3, 1)[:, None].to(tl.float32)
|
| 690 |
+
tmp6 = float("-inf")
|
| 691 |
+
tmp7 = tmp4 == tmp6
|
| 692 |
+
tmp8 = tmp0 - tmp4
|
| 693 |
+
tmp9 = 0.0
|
| 694 |
+
tmp10 = tl.where(tmp7, tmp9, tmp8)
|
| 695 |
+
tmp11 = libdevice.exp2(tmp10)
|
| 696 |
+
tmp12 = tmp5 * tmp11
|
| 697 |
+
tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK])
|
| 698 |
+
tmp15 = tl.where(xmask, tmp13, 0)
|
| 699 |
+
tmp16 = tl.sum(tmp15, 1)[:, None].to(tl.float32)
|
| 700 |
+
tmp17 = 1.0
|
| 701 |
+
tmp18 = tl.where(tmp7, tmp17, tmp16)
|
| 702 |
+
tmp19 = libdevice.log2(tmp18)
|
| 703 |
+
tmp20 = tmp19 + tmp4
|
| 704 |
+
tmp21 = 0.6931471805599453
|
| 705 |
+
tmp22 = tmp20 * tmp21
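# tmp4 above is the global rowmax across the 32 KV splits, tmp16 the sumexp
# rescaled to that max, and tmp22 = (log2(sumexp) + rowmax) * ln(2) converts
# the logsumexp from base-2 back to natural-log units.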
|
| 706 |
+
tl.store(out_ptr2 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp22, xmask)
|
| 707 |
+
tl.store(out_ptr0 + (x0), tmp4, xmask)
|
| 708 |
+
tl.store(out_ptr1 + (x0), tmp16, xmask)
|
| 709 |
+
''', device_str='cuda')
|
| 710 |
+
|
| 711 |
+
|
| 712 |
+
# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/jt/cjtngjzio5oudkq4n4xggwz5enmgujrff3ktfnon7oykgb7as5tu.py
|
| 713 |
+
# Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
|
| 714 |
+
# Source node to ATen node mapping:
|
| 715 |
+
# flex_attention => flex_attention, getitem
|
| 716 |
+
# Graph fragment:
|
| 717 |
+
# %buf2 : Tensor "f32[1, 32, 32, s37, 128][131072*s37, 4096*s37, 128*s37, 128, 1]cuda:3" = PlaceHolder[target=buf2]
|
| 718 |
+
# %buf5 : Tensor "f32[1, 1, 32, s37][32*s37, 32*s37, s37, 1]cuda:3" = PlaceHolder[target=buf5]
|
| 719 |
+
# %buf3 : Tensor = PlaceHolder[target=buf3]
|
| 720 |
+
# %buf8 : Tensor "f32[1, 32, s37, 128][4096*s37, 128*s37, 128, 1]cuda:3" = PlaceHolder[target=buf8]
|
| 721 |
+
# %buf7 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:3" = PlaceHolder[target=buf7]
|
| 722 |
+
# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
|
| 723 |
+
# %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:3"[num_users=1] = call_function[target=operator.getitem](args = (%flex_attention, 0), kwargs = {})
|
| 724 |
+
# return %buf8,%getitem
|
| 725 |
+
triton_per_fused_2 = async_compile.triton('triton_per_fused_2', '''
|
| 726 |
+
import triton
|
| 727 |
+
import triton.language as tl
|
| 728 |
+
|
| 729 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 730 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 731 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 732 |
+
triton_helpers.set_driver_to_gpu()
|
| 733 |
+
|
| 734 |
+
@triton_heuristics.persistent_reduction(
|
| 735 |
+
size_hints={'x': 524288, 'r0_': 32},
|
| 736 |
+
reduction_hint=ReductionHint.DEFAULT,
|
| 737 |
+
filename=__file__,
|
| 738 |
+
triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]},
|
| 739 |
+
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
|
| 740 |
+
)
|
| 741 |
+
@triton.jit
|
| 742 |
+
def triton_per_fused_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr):
|
| 743 |
+
r0_numel = 32
|
| 744 |
+
R0_BLOCK: tl.constexpr = 32
|
| 745 |
+
rnumel = r0_numel
|
| 746 |
+
RBLOCK: tl.constexpr = R0_BLOCK
|
| 747 |
+
xoffset = tl.program_id(0) * XBLOCK
|
| 748 |
+
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
|
| 749 |
+
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
|
| 750 |
+
r0_index = tl.arange(0, R0_BLOCK)[None, :]
|
| 751 |
+
r0_offset = 0
|
| 752 |
+
r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
|
| 753 |
+
roffset = r0_offset
|
| 754 |
+
rindex = r0_index
|
| 755 |
+
r0_2 = r0_index
|
| 756 |
+
x5 = xindex
|
| 757 |
+
x1 = xindex // 128
|
| 758 |
+
x0 = (xindex % 128)
|
| 759 |
+
x3 = ((xindex // 128) % ks0)
|
| 760 |
+
x4 = xindex // ks1
|
| 761 |
+
tmp0 = tl.load(in_ptr0 + (x5 + 4096*ks0*r0_2), None)
|
| 762 |
+
tmp1 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last')
|
| 763 |
+
tmp4 = tl.load(in_ptr2 + (x1 + 32*ks0*r0_2), None, eviction_policy='evict_last')
|
| 764 |
+
tmp13 = tl.load(in_ptr3 + (x1), None, eviction_policy='evict_last')
|
| 765 |
+
tmp2 = float("-inf")
|
| 766 |
+
tmp3 = tmp1 == tmp2
|
| 767 |
+
tmp5 = tmp4 - tmp1
|
| 768 |
+
tmp6 = 0.0
|
| 769 |
+
tmp7 = tl.where(tmp3, tmp6, tmp5)
|
| 770 |
+
tmp8 = libdevice.exp2(tmp7)
|
| 771 |
+
tmp9 = tmp0 * tmp8
|
| 772 |
+
tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK])
|
| 773 |
+
tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32)
|
| 774 |
+
tmp14 = 1.0
|
| 775 |
+
tmp15 = tl.where(tmp3, tmp14, tmp13)
|
| 776 |
+
tmp16 = (tmp12 / tmp15)
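# tmp8 rescales each split's partial accumulator from its local rowmax to the
# global rowmax, tmp12 sums over the 32 splits, and dividing by the global
# sumexp (tmp15, forced to 1.0 for fully masked rows) completes the softmax.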
|
| 777 |
+
tmp17 = tmp16.to(tl.float32)
|
| 778 |
+
tl.store(out_ptr1 + (x0 + 128*x4 + 4096*x3), tmp17, None)
|
| 779 |
+
''', device_str='cuda')
|
| 780 |
+
|
| 781 |
+
|
| 782 |
+
async_compile.wait(globals())
|
| 783 |
+
del async_compile
|
| 784 |
+
|
| 785 |
+
class Runner:
|
| 786 |
+
def __init__(self, partitions):
|
| 787 |
+
self.partitions = partitions
|
| 788 |
+
|
| 789 |
+
def recursively_apply_fns(self, fns):
|
| 790 |
+
new_callables = []
|
| 791 |
+
for fn, c in zip(fns, self.partitions):
|
| 792 |
+
new_callables.append(fn(c))
|
| 793 |
+
self.partitions = new_callables
|
| 794 |
+
|
| 795 |
+
def call(self, args):
|
| 796 |
+
arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1 = args
|
| 797 |
+
args.clear()
|
| 798 |
+
s50 = arg0_1
|
| 799 |
+
s0 = arg2_1
|
| 800 |
+
s43 = arg4_1
|
| 801 |
+
s37 = arg7_1
|
| 802 |
+
s71 = arg8_1
|
| 803 |
+
assert_size_stride(arg1_1, (1, 32, s37, 128), (4096*s37, 128, 4096, 1))
|
| 804 |
+
assert_size_stride(arg3_1, (1, 8, s0, 128), (1024*s0, 128, 1024, 1))
|
| 805 |
+
assert_size_stride(arg5_1, (1, 8, s43, 128), (1024*s43, 128, 1024, 1))
|
| 806 |
+
assert_size_stride(arg6_1, (1, 1, 1, 1), (1, 1, 1, 1))
|
| 807 |
+
assert_size_stride(arg9_1, (1, 1, 1), (1, 1, 1))
|
| 808 |
+
assert_size_stride(arg10_1, (1, 1, 1), (1, 1, 1))
|
| 809 |
+
assert_size_stride(arg11_1, (1, 1, 1, 1), (1, 1, 1, 1))
|
| 810 |
+
assert_size_stride(arg12_1, (1, 1, 1), (1, 1, 1))
|
| 811 |
+
assert_size_stride(arg13_1, (1, 1, 1, 1), (1, 1, 1, 1))
|
| 812 |
+
assert_size_stride(arg14_1, (1, 1, 1), (1, 1, 1))
|
| 813 |
+
assert_size_stride(arg15_1, (1, 1, 1, 1), (1, 1, 1, 1))
|
| 814 |
+
with torch.cuda._DeviceGuard(3):
|
| 815 |
+
torch.cuda.set_device(3)
|
| 816 |
+
buf0 = empty_strided_cuda((1, 32, 32, s37), (1024*s37, 32*s37, s37, 1), torch.float32)
|
| 817 |
+
buf1 = empty_strided_cuda((1, 32, 32, s37), (1024*s37, 32*s37, s37, 1), torch.float32)
|
| 818 |
+
buf2 = empty_strided_cuda((1, 32, 32, s37, 128), (131072*s37, 4096*s37, 128*s37, 128, 1), torch.float32)
|
| 819 |
+
# Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
|
| 820 |
+
stream3 = get_raw_stream(3)
|
| 821 |
+
triton_tem_fused_0.run(arg1_1, arg3_1, arg5_1, buf0, buf1, arg9_1, arg6_1, arg10_1, arg11_1, buf2, s37, s0, s43, 8, 32, 1, stream=stream3)
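# Launch grid (8, 32, 1): 8 programs cover batch x kv-heads (Z = 1, HKV = 8)
# and 32 cover the KV splits; buf0/buf1 receive the per-split rowmax/sumexp
# and buf2 the per-split partial outputs.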
|
| 822 |
+
del arg10_1
|
| 823 |
+
del arg11_1
|
| 824 |
+
del arg1_1
|
| 825 |
+
del arg3_1
|
| 826 |
+
del arg5_1
|
| 827 |
+
del arg6_1
|
| 828 |
+
del arg9_1
|
| 829 |
+
buf5 = empty_strided_cuda((1, 1, 32, s37), (32*s37, 32*s37, s37, 1), torch.float32)
|
| 830 |
+
buf7 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32)
|
| 831 |
+
buf10 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32)
|
| 832 |
+
# Topologically Sorted Source Nodes: [flex_attention, lse_scaled], Original ATen: [aten.mul]
|
| 833 |
+
triton_per_fused_mul_1_xnumel = 32*s37
|
| 834 |
+
stream3 = get_raw_stream(3)
|
| 835 |
+
triton_per_fused_mul_1.run(buf0, buf1, buf5, buf7, buf10, s37, triton_per_fused_mul_1_xnumel, 32, stream=stream3)
|
| 836 |
+
del buf1
|
| 837 |
+
ps0 = 128*s37
|
| 838 |
+
buf9 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16)
|
| 839 |
+
# Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
|
| 840 |
+
triton_per_fused_2_xnumel = 4096*s37
|
| 841 |
+
stream3 = get_raw_stream(3)
|
| 842 |
+
triton_per_fused_2.run(buf2, buf5, buf0, buf7, buf9, s37, ps0, triton_per_fused_2_xnumel, 32, stream=stream3)
|
| 843 |
+
del buf0
|
| 844 |
+
del buf2
|
| 845 |
+
del buf5
|
| 846 |
+
del buf7
|
| 847 |
+
return (buf9, buf10, )
|
| 848 |
+
|
| 849 |
+
runner = Runner(partitions=[])
|
| 850 |
+
call = runner.call
|
| 851 |
+
recursively_apply_fns = runner.recursively_apply_fns
|
| 852 |
+
|
| 853 |
+
|
| 854 |
+
def benchmark_compiled_module(times=10, repeat=10):
|
| 855 |
+
from torch._dynamo.testing import rand_strided
|
| 856 |
+
from torch._inductor.utils import print_performance
|
| 857 |
+
arg0_1 = 96
|
| 858 |
+
arg1_1 = rand_strided((1, 32, 96, 128), (393216, 128, 4096, 1), device='cuda:3', dtype=torch.bfloat16)
|
| 859 |
+
arg2_1 = 96
|
| 860 |
+
arg3_1 = rand_strided((1, 8, 96, 128), (98304, 128, 1024, 1), device='cuda:3', dtype=torch.bfloat16)
|
| 861 |
+
arg4_1 = 96
|
| 862 |
+
arg5_1 = rand_strided((1, 8, 96, 128), (98304, 128, 1024, 1), device='cuda:3', dtype=torch.bfloat16)
|
| 863 |
+
arg6_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:3', dtype=torch.int32)
|
| 864 |
+
arg7_1 = 96
|
| 865 |
+
arg8_1 = 96
|
| 866 |
+
arg9_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:3', dtype=torch.int32)
|
| 867 |
+
arg10_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:3', dtype=torch.int32)
|
| 868 |
+
arg11_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:3', dtype=torch.int32)
|
| 869 |
+
arg12_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:3', dtype=torch.int32)
|
| 870 |
+
arg13_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:3', dtype=torch.int32)
|
| 871 |
+
arg14_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:3', dtype=torch.int32)
|
| 872 |
+
arg15_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:3', dtype=torch.int32)
|
| 873 |
+
fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1])
|
| 874 |
+
return print_performance(fn, times=times, repeat=repeat)
|
| 875 |
+
|
| 876 |
+
|
| 877 |
+
if __name__ == "__main__":
|
| 878 |
+
from torch._inductor.wrapper_benchmark import compiled_module_main
|
| 879 |
+
compiled_module_main('None', benchmark_compiled_module)
|
progress/SpecForge/cache/compiled_kernels/2g/c2gswut4q57fp2ueybipg5qfqiy5coitofujwdnvqdwhr7nbvnyq.py
ADDED
|
@@ -0,0 +1,534 @@
| 1 |
+
|
| 2 |
+
import triton
|
| 3 |
+
import triton.language as tl
|
| 4 |
+
|
| 5 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 6 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 7 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 8 |
+
|
| 9 |
+
@triton_heuristics.template(
|
| 10 |
+
|
| 11 |
+
num_stages=3,
|
| 12 |
+
num_warps=8,
|
| 13 |
+
triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]},
|
| 14 |
+
inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
|
| 15 |
+
|
| 16 |
+
)
|
| 17 |
+
@triton.jit
|
| 18 |
+
def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1):
|
| 19 |
+
PRESCALE_QK : tl.constexpr = False
|
| 20 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 21 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 22 |
+
WRITE_DQ : tl.constexpr = True
|
| 23 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 24 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 25 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 26 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 27 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 28 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 29 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 30 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 31 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 32 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 33 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 34 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 35 |
+
USE_TMA : tl.constexpr = False
|
| 36 |
+
BLOCK_M : tl.constexpr = 128
|
| 37 |
+
BLOCK_N : tl.constexpr = 64
|
| 38 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 39 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 40 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 41 |
+
Q = arg_Q
|
| 42 |
+
K = arg_K
|
| 43 |
+
V = arg_V
|
| 44 |
+
LSE = arg_LSE
|
| 45 |
+
MAX = arg_MAX
|
| 46 |
+
KV_NUM_BLKS = arg_KV_NUM_BLKS
|
| 47 |
+
KV_IDX = arg_KV_IDX
|
| 48 |
+
FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
|
| 49 |
+
FULL_KV_IDX = arg_FULL_KV_IDX
|
| 50 |
+
|
| 51 |
+
# Sub notation for this kernel:
|
| 52 |
+
#
|
| 53 |
+
# Q: Query, K: Key, V: Value
|
| 54 |
+
# M: Number of queries, N: Number of keys/values, D: Model dimension
|
| 55 |
+
# QK_HEAD_DIM: The dimension of the query and key embeddings
|
| 56 |
+
# V_HEAD_DIM: The dimension of the value embeddings
|
| 57 |
+
# z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
|
| 58 |
+
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
|
| 59 |
+
#
|
| 60 |
+
# The following FULL_* and PARTIAL_* are defined on the block-sparse mask grid, rather than the thread block grid.
|
| 61 |
+
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
|
| 62 |
+
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
|
| 63 |
+
# FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
|
| 64 |
+
# FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
|
| 65 |
+
#
|
| 66 |
+
# OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
|
| 67 |
+
#
|
| 68 |
+
# (Modifiable) Performance tuning options
|
| 69 |
+
# BLOCK_M: The thread block size across the seqlen dim of Q.
|
| 70 |
+
# BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
|
| 71 |
+
|
| 72 |
+
# The below are kernel options that can be applied for certain score_mods,
|
| 73 |
+
# or involve a numerics vs. perf tradeoff
|
| 74 |
+
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
|
| 75 |
+
# about 20% more numerical error, but slightly faster.
|
| 76 |
+
# ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
|
| 77 |
+
# is not masked out? If so, we can skip an extra safety check
|
| 78 |
+
# BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
|
| 79 |
+
# contiguous? If so, we don't need to do an indirect jump for every block
|
| 80 |
+
|
| 81 |
+
tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
|
| 82 |
+
tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
|
| 83 |
+
|
| 84 |
+
# Define strides of inputs
|
| 85 |
+
stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1
|
| 86 |
+
stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1
|
| 87 |
+
stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1
|
| 88 |
+
|
| 89 |
+
ZQ = 1
|
| 90 |
+
HQ = 32
|
| 91 |
+
Q_LEN = ks0
|
| 92 |
+
ZKV = 1
|
| 93 |
+
KV_LEN = ks1
|
| 94 |
+
|
| 95 |
+
MATMUL_PRECISION = Q.dtype.element_ty
|
| 96 |
+
|
| 97 |
+
q_start = tl.program_id(0).to(INDEX_DTYPE)
|
| 98 |
+
off_zq = tl.program_id(1).to(INDEX_DTYPE)
|
| 99 |
+
off_hq = tl.program_id(2).to(INDEX_DTYPE)
|
| 100 |
+
|
| 101 |
+
# We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
|
| 102 |
+
# b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
|
| 103 |
+
off_zkv = off_zq % ZKV
|
| 104 |
+
off_hkv = off_hq // GQA_SHARED_HEADS
|
| 105 |
+
off_g = off_hq % GQA_SHARED_HEADS
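# GQA mapping example: with 32 query heads and GQA_SHARED_HEADS = 4, query
# head 13 reads kv head 13 // 4 = 3 (group offset 13 % 4 = 1).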
|
| 106 |
+
|
| 107 |
+
q_offset = off_zq * stride_qz + off_hq * stride_qh
|
| 108 |
+
k_offset = off_zkv * stride_kz + off_hkv * stride_kh
|
| 109 |
+
v_offset = off_zkv * stride_vz + off_hkv * stride_vh
|
| 110 |
+
|
| 111 |
+
Q = Q + q_offset
|
| 112 |
+
K = K + k_offset
|
| 113 |
+
V = V + v_offset
|
| 114 |
+
|
| 115 |
+
# Setting up the TMA descriptors for Q, K, V
|
| 116 |
+
desc_q = None
|
| 117 |
+
desc_k = None
|
| 118 |
+
desc_v = None
|
| 119 |
+
|
| 120 |
+
SPARSE_Z = 1
|
| 121 |
+
SPARSE_HQ = 1
|
| 122 |
+
|
| 123 |
+
sparse_idx_z = off_zq % SPARSE_Z
|
| 124 |
+
sparse_idx_hq = off_hq % SPARSE_HQ
|
| 125 |
+
|
| 126 |
+
SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
|
| 127 |
+
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
|
| 128 |
+
|
| 129 |
+
stride_kv_num_blks_h = 1
|
| 130 |
+
stride_kv_idx_h = 1
|
| 131 |
+
stride_kv_idx_m = 1
|
| 132 |
+
|
| 133 |
+
# initialize pointer to m and l
|
| 134 |
+
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
| 135 |
+
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
| 136 |
+
acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
| 137 |
+
|
| 138 |
+
offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 139 |
+
|
| 140 |
+
# KV_IDX and KV_NUM_BLKS are always contiguous.
|
| 141 |
+
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
|
| 142 |
+
sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
|
| 143 |
+
sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
|
| 144 |
+
offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 145 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 146 |
+
q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
|
| 147 |
+
|
| 148 |
+
# ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 149 |
+
# We don't know anything "special" about these blocks, so we need to apply
|
| 150 |
+
# both score_mod and mask_mod to them
|
| 151 |
+
kv_indices = KV_IDX + sparse_kv_idx_offset
|
| 152 |
+
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
|
| 153 |
+
kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
|
| 154 |
+
block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))

    # K and V pointers will be passed directly to forward_inner

    offs_n = kv_start + tl.arange(0, BLOCK_N)

    acc, l_i, m_i = forward_inner(
        arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
        q, K, V,
        desc_k, desc_v, Q_LEN, KV_LEN,
        acc, l_i, m_i,
        off_zq, off_hq, offs_m[:, None], offs_n[None, :],
        kv_start,
        kv_indices, kv_num_blocks,
        0, block_n_end,
        MATMUL_PRECISION,
        stride_kk, stride_kn, stride_vn, stride_vk,
        IS_FULL_BLOCKS=False,
    )

    # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # We know these blocks are guaranteed to be "full", so we don't need to
    # apply mask_mod to them - only score_mod
    if HAS_FULL_BLOCKS:
        # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
        kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
        kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE  # first kv block we're loading
        kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
        block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
        # K and V pointers will be passed directly to forward_inner
        offs_n = kv_start + tl.arange(0, BLOCK_N)

        acc, l_i, m_i = forward_inner(
            arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
            q, K, V,
            desc_k, desc_v, Q_LEN, KV_LEN,
            acc, l_i, m_i,
            off_zq, off_hq, offs_m[:, None], offs_n[None, :],
            kv_start,
            kv_indices, kv_num_blocks,
            0, block_n_end,
            MATMUL_PRECISION,
            stride_kk, stride_kn, stride_vn, stride_vk,
            IS_FULL_BLOCKS=True,
        )

    # [Note] Handle fully masked out rows:
    # l_i will be sum(e^(-inf)) == 0.0 for masked out rows, and m_i will be -inf.
    # We set l_i to 1.0, which yields lse/out == 0.0 after the log(l_i) + m_i (== 0.0) step.
    l_i = tl.where(l_i == 0.0, 1, l_i)
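    # Example (illustrative): a fully masked row has every score == -inf, so
    # l_i == 0 and m_i == -inf; substituting l_i = 1 makes acc / l_i == 0 for that
    # output row, and its lse below becomes m_i + log2(1) = -inf.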

    acc = acc / l_i[:, None]
    idx_zq = tl.program_id(1).to(INDEX_DTYPE)
    idx_hq = tl.program_id(2).to(INDEX_DTYPE)
    idx_m = offs_m[:, None].to(INDEX_DTYPE)
    idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)

    mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)

    tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
    xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0
    tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask)

    if OUTPUT_LOGSUMEXP:
        off_hz = off_zq * HQ + off_hq
        l_ptrs = LSE + off_hz * Q_LEN + offs_m
        lse = m_i + tl.math.log2(l_i)
        if IS_DIVISIBLE:
            tl.store(l_ptrs, lse)
        else:
            tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
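        # Note (illustrative): the stored LSE is in base 2 (scores were scaled by
        # RCP_LN2 = 1/ln(2) and accumulated with exp2), so downstream code recovers
        # the natural-log LSE by multiplying by ln(2) = 0.6931..., which is what the
        # triton_poi_fused_mul_1 kernels later in this diff do.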

    if OUTPUT_MAX:
        off_hz = off_zq * HQ + off_hq
        max_ptrs = MAX + off_hz * Q_LEN + offs_m
        if IS_DIVISIBLE:
            tl.store(max_ptrs, m_i)
        else:
            tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)


# Utility triton funcs
@triton.jit
def get_offset_for_next_block(
    loop_iter, col_indices, total_blocks,
    SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr
):
    if BLOCKS_ARE_CONTIGUOUS:
        return BLOCK
    cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
    cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
    next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
    needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
    jump_to_block = (next_block - cur_block) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
    offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
    return offset
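    # Worked example (illustrative): with kv_indices = [0, 4, ...], SPARSE_BLOCK = 128,
    # SPARSE_BLOCK_MULTIPLE = 2 and BLOCK = 64, iterations inside sparse block 0
    # advance by BLOCK (64); on loop_iter = 1 (the block's last inner iteration) the
    # offset is (4 - 0) * 128 - 1 * 64 = 448, jumping the pointer from kv position 64
    # straight to 512 = 4 * 128, the start of the next unmasked sparse block.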

@triton.jit
def get_bounded_indices(indices, max_len=None):
    return indices % max_len if max_len is not None else indices

@triton.jit
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
    if IS_DIVISIBLE and SAFE_HEAD_DIM:
        return tl.load(block_ptr)
    elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
    elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
    else:
        return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")

@triton.jit
def load_checked_2d(
    ptr,
    offs_m,
    offs_n,
    stride_m,
    stride_n,
    IS_DIVISIBLE_M: tl.constexpr,
    IS_DIVISIBLE_N: tl.constexpr,
    M_LEN: tl.constexpr,
    N_LEN: tl.constexpr,
):
    # Calculate the final pointer if strides are provided
    if stride_m is not None and stride_n is not None:
        ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n

    # Handle all masking cases
    if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
    elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
    elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
    else:  # Both divisible
        return tl.load(ptr)
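    # Usage note (illustrative): the template above passes IS_DIVISIBLE for the M
    # dimension and SAFE_HEAD_DIM for the N dimension; with IS_DIVISIBLE = False
    # and SAFE_HEAD_DIM = True only the row mask offs_m[:, None] < M_LEN is
    # emitted, so out-of-range rows load 0.0 and the head dim is unchecked.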


# Common Imports
@triton.jit
def forward_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
    q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
    # accumulated values
    acc, l_i, m_i,
    # Offsets
    off_z, off_h, offs_m, offs_n,
    # Offsets needed for TMA loads
    kv_start,
    kv_offset,
    MATMUL_PRECISION, RCP_LN2,
    # Strides for K and V
    stride_kk, stride_kn, stride_vn, stride_vk,
    IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
):
    # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    USE_TMA : tl.constexpr = False
    BLOCK_M : tl.constexpr = 128
    BLOCK_N : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    # -- load k --
    # NB reversed order since K is transposed
    kv_base_offset = kv_start + kv_offset

    # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED], then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
    k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)

    k = tl.trans(k)
    # -- compute qk ---
    qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION)  # TODO: use cuda matmul when q_len <= 2.
    if not PRESCALE_QK:
        qk *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    # If this is the last block of a non-divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
    # which is more than the actual number of elements. To avoid out-of-bounds memory accesses,
    # we need to mask out the elements that are beyond Q_LEN & KV_LEN.
    m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
    n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)

    tmp0 = (qk)
    post_mod_scores = tmp0

    if CHECK_BLOCK_BOUNDARY:
        # Mask out the elements that are beyond KV_LEN for a non-divisible seqlen.
        post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        tmp1 = (m)
        tmp2 = tl.full([1], 0, tl.int32)
        tmp3 = tmp1 < tmp2
        tmp4 = (n)
        tmp5 = tmp4 <= tmp1
        tmp6 = tmp3 & tmp5
        tmp7 = tmp1 >= tmp2
        tmp8 = tmp4 < tmp2
        tmp9 = tmp7 & tmp8
        tmp10 = tmp8 == 0
        tmp11 = tmp7 & tmp10
        tmp12 = tmp1 - tmp2
        tmp13 = tl.full([1], 16, tl.int32)
        tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
        tmp15 = tmp4 - tmp2
        tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
        tmp17 = tmp14 == tmp16
        tmp18 = tmp11 & tmp17
        tmp19 = tmp9 | tmp18
        tmp20 = tmp6 | tmp19
        mask_mod_output = tmp20

        if CHECK_BLOCK_BOUNDARY:
            mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
        # apply the mask for partially unmasked blocks
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))

    if not PRESCALE_QK:
        post_mod_scores *= RCP_LN2
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # -- compute scaling constant ---
    m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
    if not ROWS_GUARANTEED_SAFE:
        masked_out_rows = (m_ij == float("-inf"))
        m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
    else:
        m_ij_masked = m_ij

    alpha = tl.math.exp2(m_i - m_ij_masked)
    p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])

    # NB: the l_i update is pulled up here since it's a bit faster
    # NB: for headdim=256, it's faster to move it back down to after m_i = m_ij
    l_i = l_i * alpha + tl.sum(p, 1)
    # -- scale and update acc --
    acc = acc * alpha[:, None]
    # Calculate offsets for V loading - reuse kv_base_offset from the K load
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
    v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
    acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)

    # -- update m_i
    m_i = m_ij

    return acc, l_i, m_i
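# Summary of the update above (illustrative, base-2 online softmax):
#   m_new = max(m_old, rowmax(scores));  alpha = 2^(m_old - m_new)
#   l     = l * alpha + sum(2^(scores - m_new), axis=1)
#   acc   = acc * alpha + p @ v
# i.e. one streaming softmax-rescale step per KV block.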

@triton.jit
def forward_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
    q, K, V,
    desc_k, desc_v, Q_LEN, KV_LEN,
    # accumulated values
    acc, l_i, m_i,
    # Offsets used as inputs to score_mod & mask_mod
    # of size [BLOCK_M, BLOCK_N] or scalar.
    off_z, off_h, offs_m, offs_n,
    # Offsets needed for TMA loads
    kv_start,
    # blocksparse data
    kv_indices, kv_num_blocks,
    # start kv and end kv block
    block_n_start, block_n_end,
    MATMUL_PRECISION,
    # Strides for K and V
    stride_kk, stride_kn, stride_vn, stride_vk,
    IS_FULL_BLOCKS,
):
    # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    USE_TMA : tl.constexpr = False
    BLOCK_M : tl.constexpr = 128
    BLOCK_N : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
    RCP_LN2: tl.constexpr = 1.44269504

    if PRESCALE_QK:
        q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)

    kv_offset = 0

    # loop over k, v and update the accumulator until block_n_end
    for start_n in range(block_n_start, block_n_end):
        # Here IS_DIVISIBLE acts as the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
        if IS_DIVISIBLE:
            acc, l_i, m_i = forward_block_mn(
                arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
                q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
                # accumulated values
                acc, l_i, m_i,
                # Offsets
                off_z, off_h, offs_m, offs_n,
                # Offsets needed for TMA loads
                kv_start,
                kv_offset,
                MATMUL_PRECISION, RCP_LN2,
                # Strides for K and V
                stride_kk, stride_kn, stride_vn, stride_vk,
                IS_FULL_BLOCKS,
            )
        else:
            # Benchmarks show that even if we apply mod & mask to every block for a non-divisible seqlen,
            # it's on par with, or slightly faster than, applying them only to the last block in fwd.
            # However, we choose a different strategy for bwd, where we only apply mod & mask
            # to the last block because it's a lot faster.
            acc, l_i, m_i = forward_block_mn(
                arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
                q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
                # accumulated values
                acc, l_i, m_i,
                # Offsets
                off_z, off_h, offs_m, offs_n,
                # Offsets needed for TMA loads
                kv_start,
                kv_offset,
                MATMUL_PRECISION, RCP_LN2,
                # Strides for K and V
                stride_kk, stride_kn, stride_vn, stride_vk,
                IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
            )

        offset = get_offset_for_next_block(
            start_n, kv_indices, kv_num_blocks,
            SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
        )

        offs_n = offs_n + offset
        kv_offset += offset

    return acc, l_i, m_i
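For orientation, a minimal NumPy sketch of the online-softmax recurrence that forward_inner/forward_block_mn implement, stripped of block sparsity, masking, GQA, and the base-2 trick (natural log instead); online_softmax_attention is an illustrative name, not part of this cache:

import numpy as np

def online_softmax_attention(q, k, v, block_n=64):
    # Running max (m), normalizer (l), and un-normalized accumulator (acc),
    # updated one KV block at a time, as in forward_block_mn.
    m = np.full(q.shape[0], -np.inf)
    l = np.zeros(q.shape[0])
    acc = np.zeros((q.shape[0], v.shape[1]))
    for s in range(0, k.shape[0], block_n):
        scores = q @ k[s:s + block_n].T
        m_new = np.maximum(m, scores.max(axis=1))
        alpha = np.exp(m - m_new)  # rescale previously accumulated state
        p = np.exp(scores - m_new[:, None])
        l = l * alpha + p.sum(axis=1)
        acc = acc * alpha[:, None] + p @ v[s:s + block_n]
        m = m_new
    return acc / l[:, None]

rng = np.random.default_rng(0)
q, k, v = rng.standard_normal((8, 16)), rng.standard_normal((128, 16)), rng.standard_normal((128, 16))
s = q @ k.T
p = np.exp(s - s.max(axis=1, keepdims=True))
ref = (p / p.sum(axis=1, keepdims=True)) @ v
assert np.allclose(online_softmax_attention(q, k, v), ref)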
progress/SpecForge/cache/compiled_kernels/2j/4b74fa21eaaf86b6290185f6fe50aec9b905d858a087238ceddb52477f3f6acb.best_config
ADDED
@@ -0,0 +1 @@
{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "2ZIFGDABR2MKMG7ESWF67GBZDP27JEZIQWMBXPOUZFGMG5PW5DSA"}
progress/SpecForge/cache/compiled_kernels/2j/c2j3mtk3thi6sn2hxiuhuigjw43spiu74mxdervpgpfrtos7u2qh.py
ADDED
@@ -0,0 +1,28 @@

import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 4096},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x2 = xindex
    x0 = (xindex % ks0)
    x1 = triton_helpers.div_floor_integer(xindex, ks0)
    tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last')
    tmp1 = 0.6931471805599453
    tmp2 = tmp0 * tmp1
    tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask)
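The 0.6931471805599453 constant in this kernel is ln(2): the attention template stores the logsumexp in base 2, and this pointwise pass rescales it to the natural log. A quick sanity check (illustrative, not part of the cache):

import numpy as np
x = np.array([0.5, 1.0, 2.0, 8.0])
assert np.allclose(np.log(x), np.log2(x) * 0.6931471805599453)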
progress/SpecForge/cache/compiled_kernels/2n/c2ngvuchx6agpdr6v7awl3qgblaehfzaauoxn6camwvtk7syoxsk.py
ADDED
@@ -0,0 +1,715 @@
# AOT ID: ['4_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream

aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
assert_alignment = torch._C._dynamo.guards.assert_alignment
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p


# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/6u/c6uror2yjtc6vpcc3on3oq3lwi6yghlxrmwz5rocw5haxvfiz47e.py
# Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
# Source node to ATen node mapping:
#   flex_attention => flex_attention
# Graph fragment:
#   %arg1_1 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:2" = PlaceHolder[target=arg1_1]
#   %arg3_1 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:2" = PlaceHolder[target=arg3_1]
#   %arg5_1 : Tensor "bf16[1, 8, s43, 128][1024*s43, 128, 1024, 1]cuda:2" = PlaceHolder[target=arg5_1]
#   %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:2" = PlaceHolder[target=getitem_1]
#   %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:2" = PlaceHolder[target=buf1]
#   %arg9_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:2" = PlaceHolder[target=arg9_1]
#   %arg6_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:2" = PlaceHolder[target=arg6_1]
#   %arg10_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:2" = PlaceHolder[target=arg10_1]
#   %arg11_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:2" = PlaceHolder[target=arg11_1]
#   %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
#   return %getitem
triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties

@triton_heuristics.template(
    num_stages=3,
    num_warps=8,
    triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]},
    inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
)
@triton.jit
def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'ieee'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    USE_TMA : tl.constexpr = False
    BLOCK_M : tl.constexpr = 128
    BLOCK_N : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32
    Q = arg_Q
    K = arg_K
    V = arg_V
    LSE = arg_LSE
    MAX = arg_MAX
    KV_NUM_BLKS = arg_KV_NUM_BLKS
    KV_IDX = arg_KV_IDX
    FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
    FULL_KV_IDX = arg_FULL_KV_IDX

    # Sub notation for this kernel:
    #
    # Q: Query, K: Key, V: Value
    # M: Number of queries, N: Number of keys/values, D: Model dimension
    # QK_HEAD_DIM: The dimension of the query and key embeddings
    # V_HEAD_DIM: The dimension of the value embeddings
    # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
    # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
    #
    # The following FULL_* and PARTIAL_* are defined in the block sparse mask grid, rather than the thread block grid.
    # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
    # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
    # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
    # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
    #
    # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
    #
    # (Modifiable) Performance tuning options
    # BLOCK_M: The thread block size across the seqlen dim of Q.
    # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.

    # The below are kernel options that can be applied for certain score_mods,
    # or involve a numerics vs. perf tradeoff
    # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
    # about 20% more numerical error, but slightly faster.
    # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
    # is not masked out? If so, we can skip an extra safety check
    # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
    # contiguous? If so, we don't need to do an indirect jump for every block

    tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
    tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)

    # Define strides of inputs
    stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1
    stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1
    stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1

    ZQ = 1
    HQ = 32
    Q_LEN = ks0
    ZKV = 1
    KV_LEN = ks1

    MATMUL_PRECISION = Q.dtype.element_ty

    q_start = tl.program_id(0).to(INDEX_DTYPE)
    off_zq = tl.program_id(1).to(INDEX_DTYPE)
    off_hq = tl.program_id(2).to(INDEX_DTYPE)

    # We support two cases for the batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
    # b) (ZKV == 1 and ZQ > 1) where KV is broadcast along the batch dimension and off_zkv = 0.
    off_zkv = off_zq % ZKV
    off_hkv = off_hq // GQA_SHARED_HEADS
    off_g = off_hq % GQA_SHARED_HEADS

    q_offset = off_zq * stride_qz + off_hq * stride_qh
    k_offset = off_zkv * stride_kz + off_hkv * stride_kh
    v_offset = off_zkv * stride_vz + off_hkv * stride_vh

    Q = Q + q_offset
    K = K + k_offset
    V = V + v_offset

    # Setting up the TMA descriptors for Q, K, V
    desc_q = None
    desc_k = None
    desc_v = None

    SPARSE_Z = 1
    SPARSE_HQ = 1

    sparse_idx_z = off_zq % SPARSE_Z
    sparse_idx_hq = off_hq % SPARSE_HQ

    SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)

    stride_kv_num_blks_h = 1
    stride_kv_idx_h = 1
    stride_kv_idx_m = 1

    # initialize pointers to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)

    offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)

    # KV_IDX and KV_NUM_BLKS are always contiguous.
    sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
    sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
    sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m  # noqa: B950
    offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)

    # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # We don't know anything "special" about these blocks, so we need to apply
    # both score_mod and mask_mod to them
    kv_indices = KV_IDX + sparse_kv_idx_offset
    kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE  # first kv block we're loading
    kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
    block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))

    # K and V pointers will be passed directly to forward_inner

    offs_n = kv_start + tl.arange(0, BLOCK_N)

    acc, l_i, m_i = forward_inner(
        arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
        q, K, V,
        desc_k, desc_v, Q_LEN, KV_LEN,
        acc, l_i, m_i,
        off_zq, off_hq, offs_m[:, None], offs_n[None, :],
        kv_start,
        kv_indices, kv_num_blocks,
        0, block_n_end,
        MATMUL_PRECISION,
        stride_kk, stride_kn, stride_vn, stride_vk,
        IS_FULL_BLOCKS=False,
    )

    # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # We know these blocks are guaranteed to be "full", so we don't need to
    # apply mask_mod to them - only score_mod
    if HAS_FULL_BLOCKS:
        # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
        kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
        kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE  # first kv block we're loading
        kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
        block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
        # K and V pointers will be passed directly to forward_inner
        offs_n = kv_start + tl.arange(0, BLOCK_N)

        acc, l_i, m_i = forward_inner(
            arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
            q, K, V,
            desc_k, desc_v, Q_LEN, KV_LEN,
            acc, l_i, m_i,
            off_zq, off_hq, offs_m[:, None], offs_n[None, :],
            kv_start,
            kv_indices, kv_num_blocks,
            0, block_n_end,
            MATMUL_PRECISION,
            stride_kk, stride_kn, stride_vn, stride_vk,
            IS_FULL_BLOCKS=True,
        )

    # [Note] Handle fully masked out rows:
    # l_i will be sum(e^(-inf)) == 0.0 for masked out rows, and m_i will be -inf.
    # We set l_i to 1.0, which yields lse/out == 0.0 after the log(l_i) + m_i (== 0.0) step.
    l_i = tl.where(l_i == 0.0, 1, l_i)

    acc = acc / l_i[:, None]
    idx_zq = tl.program_id(1).to(INDEX_DTYPE)
    idx_hq = tl.program_id(2).to(INDEX_DTYPE)
    idx_m = offs_m[:, None].to(INDEX_DTYPE)
    idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)

    mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)

    tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
    xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0
    tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask)

    if OUTPUT_LOGSUMEXP:
        off_hz = off_zq * HQ + off_hq
        l_ptrs = LSE + off_hz * Q_LEN + offs_m
        lse = m_i + tl.math.log2(l_i)
        if IS_DIVISIBLE:
            tl.store(l_ptrs, lse)
        else:
            tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)

    if OUTPUT_MAX:
        off_hz = off_zq * HQ + off_hq
        max_ptrs = MAX + off_hz * Q_LEN + offs_m
        if IS_DIVISIBLE:
            tl.store(max_ptrs, m_i)
        else:
            tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)


# Utility triton funcs
@triton.jit
def get_offset_for_next_block(
    loop_iter, col_indices, total_blocks,
    SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr
):
    if BLOCKS_ARE_CONTIGUOUS:
        return BLOCK
    cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
    cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
    next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
    needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
    jump_to_block = (next_block - cur_block) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
    offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
    return offset

@triton.jit
def get_bounded_indices(indices, max_len=None):
    return indices % max_len if max_len is not None else indices

@triton.jit
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
    if IS_DIVISIBLE and SAFE_HEAD_DIM:
        return tl.load(block_ptr)
    elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
    elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
    else:
        return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")

@triton.jit
def load_checked_2d(
    ptr,
    offs_m,
    offs_n,
    stride_m,
    stride_n,
    IS_DIVISIBLE_M: tl.constexpr,
    IS_DIVISIBLE_N: tl.constexpr,
    M_LEN: tl.constexpr,
    N_LEN: tl.constexpr,
):
    # Calculate the final pointer if strides are provided
    if stride_m is not None and stride_n is not None:
        ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n

    # Handle all masking cases
    if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
    elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
    elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
    else:  # Both divisible
        return tl.load(ptr)


# Common Imports
@triton.jit
def forward_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
    q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
    # accumulated values
    acc, l_i, m_i,
    # Offsets
    off_z, off_h, offs_m, offs_n,
    # Offsets needed for TMA loads
    kv_start,
    kv_offset,
    MATMUL_PRECISION, RCP_LN2,
    # Strides for K and V
    stride_kk, stride_kn, stride_vn, stride_vk,
    IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
):
    # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'ieee'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    USE_TMA : tl.constexpr = False
    BLOCK_M : tl.constexpr = 128
    BLOCK_N : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    # -- load k --
    # NB reversed order since K is transposed
    kv_base_offset = kv_start + kv_offset

    # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED], then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
    k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)

    k = tl.trans(k)
    # -- compute qk ---
    qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION)  # TODO: use cuda matmul when q_len <= 2.
    if not PRESCALE_QK:
        qk *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    # If this is the last block of a non-divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
    # which is more than the actual number of elements. To avoid out-of-bounds memory accesses,
    # we need to mask out the elements that are beyond Q_LEN & KV_LEN.
    m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
    n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)

    tmp0 = (qk)
    post_mod_scores = tmp0

    if CHECK_BLOCK_BOUNDARY:
        # Mask out the elements that are beyond KV_LEN for a non-divisible seqlen.
        post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        tmp1 = (m)
        tmp2 = tl.full([1], 0, tl.int32)
        tmp3 = tmp1 < tmp2
        tmp4 = (n)
        tmp5 = tmp4 <= tmp1
        tmp6 = tmp3 & tmp5
        tmp7 = tmp1 >= tmp2
        tmp8 = tmp4 < tmp2
        tmp9 = tmp7 & tmp8
        tmp10 = tmp8 == 0
        tmp11 = tmp7 & tmp10
        tmp12 = tmp1 - tmp2
        tmp13 = tl.full([1], 16, tl.int32)
        tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
        tmp15 = tmp4 - tmp2
        tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
        tmp17 = tmp14 == tmp16
        tmp18 = tmp11 & tmp17
        tmp19 = tmp9 | tmp18
        tmp20 = tmp6 | tmp19
        mask_mod_output = tmp20

        if CHECK_BLOCK_BOUNDARY:
            mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
        # apply the mask for partially unmasked blocks
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))

    if not PRESCALE_QK:
        post_mod_scores *= RCP_LN2
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # -- compute scaling constant ---
    m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
    if not ROWS_GUARANTEED_SAFE:
        masked_out_rows = (m_ij == float("-inf"))
        m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
    else:
        m_ij_masked = m_ij

    alpha = tl.math.exp2(m_i - m_ij_masked)
    p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])

    # NB: the l_i update is pulled up here since it's a bit faster
    # NB: for headdim=256, it's faster to move it back down to after m_i = m_ij
    l_i = l_i * alpha + tl.sum(p, 1)
    # -- scale and update acc --
    acc = acc * alpha[:, None]
    # Calculate offsets for V loading - reuse kv_base_offset from the K load
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
    v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
    acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)

    # -- update m_i
    m_i = m_ij

    return acc, l_i, m_i

@triton.jit
def forward_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
    q, K, V,
    desc_k, desc_v, Q_LEN, KV_LEN,
    # accumulated values
    acc, l_i, m_i,
    # Offsets used as inputs to score_mod & mask_mod
    # of size [BLOCK_M, BLOCK_N] or scalar.
    off_z, off_h, offs_m, offs_n,
    # Offsets needed for TMA loads
    kv_start,
    # blocksparse data
    kv_indices, kv_num_blocks,
    # start kv and end kv block
    block_n_start, block_n_end,
    MATMUL_PRECISION,
    # Strides for K and V
    stride_kk, stride_kn, stride_vn, stride_vk,
    IS_FULL_BLOCKS,
):
    # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'ieee'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    USE_TMA : tl.constexpr = False
    BLOCK_M : tl.constexpr = 128
    BLOCK_N : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
    RCP_LN2: tl.constexpr = 1.44269504

    if PRESCALE_QK:
        q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)

    kv_offset = 0

    # loop over k, v and update the accumulator until block_n_end
    for start_n in range(block_n_start, block_n_end):
        # Here IS_DIVISIBLE acts as the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
        if IS_DIVISIBLE:
            acc, l_i, m_i = forward_block_mn(
                arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
                q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
                # accumulated values
                acc, l_i, m_i,
                # Offsets
                off_z, off_h, offs_m, offs_n,
                # Offsets needed for TMA loads
                kv_start,
                kv_offset,
                MATMUL_PRECISION, RCP_LN2,
                # Strides for K and V
                stride_kk, stride_kn, stride_vn, stride_vk,
                IS_FULL_BLOCKS,
            )
        else:
            # Benchmarks show that even if we apply mod & mask to every block for a non-divisible seqlen,
            # it's on par with, or slightly faster than, applying them only to the last block in fwd.
            # However, we choose a different strategy for bwd, where we only apply mod & mask
            # to the last block because it's a lot faster.
            acc, l_i, m_i = forward_block_mn(
                arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
                q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
                # accumulated values
                acc, l_i, m_i,
                # Offsets
                off_z, off_h, offs_m, offs_n,
                # Offsets needed for TMA loads
                kv_start,
                kv_offset,
                MATMUL_PRECISION, RCP_LN2,
                # Strides for K and V
                stride_kk, stride_kn, stride_vn, stride_vk,
                IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
            )

        offset = get_offset_for_next_block(
            start_n, kv_indices, kv_num_blocks,
            SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
        )

        offs_n = offs_n + offset
        kv_offset += offset

    return acc, l_i, m_i
''', device_str='cuda')


# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sc/cscnwzzlpcjsqvndc4tlfwact2ecwdimqtwu2vya2cnto5t7c7pi.py
# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul]
# Source node to ATen node mapping:
#   lse_scaled => mul_9
# Graph fragment:
#   %buf3 : Tensor = PlaceHolder[target=buf3]
#   %mul_9 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {})
#   return %mul_9
triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 4096},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x2 = xindex
    x0 = (xindex % ks0)
    x1 = triton_helpers.div_floor_integer(xindex, ks0)
    tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last')
    tmp1 = 0.6931471805599453
    tmp2 = tmp0 * tmp1
    tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask)
''', device_str='cuda')


async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1 = args
        args.clear()
        s50 = arg0_1
        s0 = arg2_1
        s43 = arg4_1
        s37 = arg7_1
        s71 = arg8_1
        assert_size_stride(arg1_1, (1, 32, s37, 128), (4096*s37, 128, 4096, 1))
        assert_size_stride(arg3_1, (1, 8, s0, 128), (1024*s0, 128, 1024, 1))
        assert_size_stride(arg5_1, (1, 8, s43, 128), (1024*s43, 128, 1024, 1))
        assert_size_stride(arg6_1, (1, 1, 1, 1), (1, 1, 1, 1))
        assert_size_stride(arg9_1, (1, 1, 1), (1, 1, 1))
        assert_size_stride(arg10_1, (1, 1, 1), (1, 1, 1))
        assert_size_stride(arg11_1, (1, 1, 1, 1), (1, 1, 1, 1))
        assert_size_stride(arg12_1, (1, 1, 1), (1, 1, 1))
        assert_size_stride(arg13_1, (1, 1, 1, 1), (1, 1, 1, 1))
        assert_size_stride(arg14_1, (1, 1, 1), (1, 1, 1))
        assert_size_stride(arg15_1, (1, 1, 1, 1), (1, 1, 1, 1))
        with torch.cuda._DeviceGuard(2):
            torch.cuda.set_device(2)
            buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32)
            buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32)
            buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16)
            # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
            stream2 = get_raw_stream(2)
            triton_tem_fused_0.run(arg1_1, arg3_1, arg5_1, buf0, buf1, arg9_1, arg6_1, arg10_1, arg11_1, buf2, s37, s0, s43, (127 + s37) // 128, 1, 32, stream=stream2)
|
| 669 |
+
del arg10_1
|
| 670 |
+
del arg11_1
|
| 671 |
+
del arg1_1
|
| 672 |
+
del arg3_1
|
| 673 |
+
del arg5_1
|
| 674 |
+
del arg6_1
|
| 675 |
+
del arg9_1
|
| 676 |
+
del buf1
|
| 677 |
+
buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32)
|
| 678 |
+
# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul]
|
| 679 |
+
triton_poi_fused_mul_1_xnumel = 32*s37
|
| 680 |
+
stream2 = get_raw_stream(2)
|
| 681 |
+
triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream2)
|
| 682 |
+
del buf0
|
| 683 |
+
return (buf2, buf5, )
|
| 684 |
+
|
| 685 |
+
runner = Runner(partitions=[])
|
| 686 |
+
call = runner.call
|
| 687 |
+
recursively_apply_fns = runner.recursively_apply_fns
|
| 688 |
+
|
| 689 |
+
|
| 690 |
+
def benchmark_compiled_module(times=10, repeat=10):
|
| 691 |
+
from torch._dynamo.testing import rand_strided
|
| 692 |
+
from torch._inductor.utils import print_performance
|
| 693 |
+
arg0_1 = 128
|
| 694 |
+
arg1_1 = rand_strided((1, 32, 128, 128), (524288, 128, 4096, 1), device='cuda:2', dtype=torch.bfloat16)
|
| 695 |
+
arg2_1 = 128
|
| 696 |
+
arg3_1 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:2', dtype=torch.bfloat16)
|
| 697 |
+
arg4_1 = 128
|
| 698 |
+
arg5_1 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:2', dtype=torch.bfloat16)
|
| 699 |
+
arg6_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:2', dtype=torch.int32)
|
| 700 |
+
arg7_1 = 128
|
| 701 |
+
arg8_1 = 128
|
| 702 |
+
arg9_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:2', dtype=torch.int32)
|
| 703 |
+
arg10_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:2', dtype=torch.int32)
|
| 704 |
+
arg11_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:2', dtype=torch.int32)
|
| 705 |
+
arg12_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:2', dtype=torch.int32)
|
| 706 |
+
arg13_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:2', dtype=torch.int32)
|
| 707 |
+
arg14_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:2', dtype=torch.int32)
|
| 708 |
+
arg15_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:2', dtype=torch.int32)
|
| 709 |
+
fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1])
|
| 710 |
+
return print_performance(fn, times=times, repeat=repeat)
|
| 711 |
+
|
| 712 |
+
|
| 713 |
+
if __name__ == "__main__":
|
| 714 |
+
from torch._inductor.wrapper_benchmark import compiled_module_main
|
| 715 |
+
compiled_module_main('None', benchmark_compiled_module)
|
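The wrapper above runs the templated flex-attention kernel (triton_tem_fused_0) and then triton_poi_fused_mul_1, whose only job is the multiply by 0.6931471805599453 = ln 2: exp2-based attention kernels accumulate the logsumexp in base 2, and scaling by ln 2 converts it to the natural-log LSE the graph names lse_scaled. A small self-contained check of that identity (assumed context, not repo code):

import math
import torch

# Base-2 accumulation as an exp2-based kernel does it, then the final
# multiply by ln(2) recovers the natural-log logsumexp.
s = torch.randn(32, 128)                                  # illustrative scores
lse_base2 = torch.log2(torch.exp2(s * math.log2(math.e)).sum(dim=-1))
lse_nat = lse_base2 * 0.6931471805599453                  # same constant as the kernel
assert torch.allclose(lse_nat, torch.logsumexp(s, dim=-1), atol=1e-4)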
progress/SpecForge/cache/compiled_kernels/2n/c2nooi7ekpz4qvmvghggbegd5cyfspb27jmq2snbi26zbrpoibnx.py ADDED
@@ -0,0 +1,48 @@
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 4096, 'r0_': 128},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_per_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr):
    r0_numel = 128
    R0_BLOCK: tl.constexpr = 128
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
    roffset = r0_offset
    rindex = r0_index
    r0_2 = r0_index
    x0 = (xindex % ks0)
    x1 = triton_helpers.div_floor_integer(xindex, ks0)
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), xmask, other=0.0).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, other=0.0).to(tl.float32)
    tmp8 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last')
    tmp2 = tmp0 * tmp1
    tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
    tmp5 = tl.where(xmask, tmp3, 0)
    tmp6 = tl.sum(tmp5, 1)[:, None].to(tl.float32)
    tmp7 = tmp6.to(tl.float32)
    tmp9 = 0.6931471805599453
    tmp10 = tmp8 * tmp9
    tmp11 = 1.4426950408889634
    tmp12 = tmp10 * tmp11
    tmp13 = tmp7 - tmp12
    tl.store(out_ptr1 + (x3), tmp13, xmask)
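In eager terms, triton_per_fused_mul_0 above is a row-wise fused reduction: for each (x0, x1) position it sums the elementwise product of the two bf16 inputs over the 128-wide last axis, then subtracts the fp32 input rescaled by ln 2 · log2 e (the two constants cancel, so the rescale is an exact identity). A hedged reference implementation, with the buffer roles inferred from the load patterns rather than confirmed by the repo:

import torch

# Hedged eager equivalent of triton_per_fused_mul_0; tensor roles are
# inferred (a, b look like out/grad_out tiles, c like a stored LSE).
def per_fused_mul_0_reference(a, b, c):
    dot = (a.float() * b.float()).sum(dim=-1)                  # reduce over the 128-wide axis
    return dot - c * 0.6931471805599453 * 1.4426950408889634  # ln2 * log2(e) == 1.0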
progress/SpecForge/cache/compiled_kernels/2n/d17ff4e7bb44e5ae89a267ef332bb7c074804ce0942fc0694c3ef15b05f7854a.best_config ADDED
@@ -0,0 +1 @@
{"XBLOCK": 8, "num_warps": 8, "num_stages": 1, "configs_hash": "22b8c9e89632e6687ce26aaad980a76bbf5ee683fff317f3a6d7989c7528ff63", "found_by_coordesc": false, "time_taken_ms": 18, "triton_cache_hash": "WJHIHLPATQZBKSQZSWJ5BD3ABYGFUF3YD6VF633RGCNWMMKVXCCA"}
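The *.best_config files are the Triton autotuner's cache: a single JSON object recording the winning launch configuration (block size, num_warps, num_stages), a hash of the candidate config set, and how long tuning took, keyed by the kernel's source hash so later runs skip autotuning. Reading one back is plain JSON (illustrative snippet, not repo tooling):

import json

# Inspect a cached autotune result; the filename is the one added above.
with open("d17ff4e7bb44e5ae89a267ef332bb7c074804ce0942fc0694c3ef15b05f7854a.best_config") as f:
    cfg = json.load(f)
print(cfg["XBLOCK"], cfg["num_warps"], cfg["num_stages"], cfg["time_taken_ms"])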
progress/SpecForge/cache/compiled_kernels/2o/c2oashzxz74kzyuwo67tuhk32cike37ysabriftachdv7lf2qxgs.py ADDED
@@ -0,0 +1,799 @@
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties

@triton_heuristics.template(
    num_stages=3,
    num_warps=8,
    triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]},
    inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
)
@triton.jit
def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32
    Q = arg_Q
    K = arg_K
    V = arg_V
    LSE = arg_LSE
    DELTA = arg_DELTA
    DO = arg_DO
    DQ = arg_DQ
    DV = arg_DV
    KV_NUM_BLKS = arg_KV_NUM_BLKS
    KV_IDX = arg_KV_IDX
    Q_NUM_BLKS = arg_Q_NUM_BLKS
    Q_IDX = arg_Q_IDX
    FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
    FULL_KV_IDX = arg_FULL_KV_IDX
    FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
    FULL_Q_IDX = arg_FULL_Q_IDX

    # Sub notation for this kernel:
    #
    # Q: Query, K: Key, V: Value
    # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
    # DELTA: Precomputed sum(OUT*DO, axis=-1)
    # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
    # DK: Derivative of Key, is written to via the store_output call due to some limitations with
    # inductor codegen
    # M: Number of queries, N: Number of keys/values
    # QK_HEAD_DIM: The dimension of the query and key embeddings
    # V_HEAD_DIM: The dimension of the value embeddings
    # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
    # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
    # (Modifiable) Performance tuning options
    # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
    # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
    # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
    # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
    #
    # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
    # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
    # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
    # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
    # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
    # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
    # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
    # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
    # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.

    # The below are kernel options that can be applied for certain score_mods,
    # or involve a numerics vs. perf tradeoff
    # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
    # about 20% more numerical error, but slightly faster.

    # Define strides of inputs
    stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1
    stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1
    stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1
    stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1

    stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1
    stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1

    ZQ = 1
    HQ = 32
    HKV = 8
    Q_LEN = ks0
    ZKV = 1
    KV_LEN = ks1

    MATMUL_PRECISION = Q.dtype.element_ty

    pid = tl.program_id(0).to(INDEX_DTYPE)
    NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
    NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)

    off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
    off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
    off_zkv = off_zq % ZKV # kv batch idx

    SPARSE_Z = 1
    SPARSE_HQ = 1

    sparse_idx_z = off_zq % SPARSE_Z

    k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
    v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
    # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
    # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
    dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)

    # offset K, V, DV pointers for batch/kv-head
    K += k_adj
    V += v_adj
    DV += dv_adj

    RCP_LN2 = 1.44269504
    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)

    if pid >= NUM_KV_BLOCKS:
        off_pid = pid - NUM_KV_BLOCKS
        # THIS BLOCK DOES DQ
        SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
        SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
        off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
        start_m2_block = off_pid % NUM_Q_BLOCKS
        off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
        stride_kv_num_blks_h = ks2
        stride_kv_idx_h = ks3*ks4
        stride_kv_idx_m = ks4

        sparse_idx_hq2 = off_hq2 % SPARSE_HQ
        sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2

        sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
        sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950

        # Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query heads.
        q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
        do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
        dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
        off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)

        Q2 = Q + q_adj2
        DO2 = DO + do_adj2
        # TODO: This does not work if DQ is not the same layout as Q (for example,
        # if Q is broadcasted)
        DQ2 = DQ + dq_adj2
        LSE2 = LSE + off_chz2
        DELTA2 = DELTA + off_chz2

        # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
        dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)

        start_m2 = start_m2_block * BLOCK_M2
        offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)

        # load Q and do: they stay in SRAM throughout the inner loop.
        q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
        do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)

        if PRESCALE_QK:
            q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)

        if IS_DIVISIBLE:
            Di = tl.load(DELTA2 + offs_m2)
            lse = tl.load(LSE2 + offs_m2)
        else:
            Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
            lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
        lse = tl.where(lse == -float("inf"), 0.0, lse)
        lse = lse[:, None]

        # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # KV_IDX and KV_NUM_BLKS are always contiguous.
        kv_indices = KV_IDX + sparse_kv_idx_offset
        kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
        sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)

        offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
        dq = bwd_dq_inner(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7,
            K, V,
            dq, q, do, Di, lse,
            off_zq, off_hq2, offs_m2, offs_n2,
            stride_kn, stride_kd, stride_vn, stride_vd,
            kv_indices, sparse_kv_num_blocks,
            MATMUL_PRECISION,
            IS_FULL_BLOCKS=False,
        )

        if HAS_FULL_BLOCKS:
            # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
            kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
            kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
            sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)

            offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
            dq = bwd_dq_inner(
                arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7,
                K, V,
                dq, q, do, Di, lse,
                off_zq, off_hq2, offs_m2, offs_n2,
                stride_kn, stride_kd, stride_vn, stride_vd,
                kv_indices, sparse_kv_num_blocks,
                MATMUL_PRECISION,
                IS_FULL_BLOCKS=True,
            )

        # Write back dQ.
        dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
        dq *= SM_SCALE
        if IS_DIVISIBLE and SAFE_HEAD_DIM:
            tl.store(dq_ptrs, dq)
        else:
            tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
    else:
        # THIS BLOCK DOES DK & DV
        SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
        SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)

        pid_mask = pid // SPARSE_KV_MULTIPLE

        stride_q_num_blks_h = ks5
        stride_q_idx_h = ks6*ks7
        stride_q_idx_n = ks6

        dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
        dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)

        start_n1 = pid * BLOCK_N1
        offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)

        # load K and V: they stay in SRAM throughout the inner loop.
        k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
        v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)

        if PRESCALE_QK:
            k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)

        for off_g in range(0, GQA_SHARED_HEADS):
            off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g

            # Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query heads.
            q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
            do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
            dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
            off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)

            Q1 = Q + q_adj1
            DO1 = DO + do_adj1
            # TODO: This does not work if DQ is not the same layout as Q (for example,
            # if Q is broadcasted)
            LSE1 = LSE + off_chz1
            DELTA1 = DELTA + off_chz1

            sparse_idx_hq1 = off_hq1 % SPARSE_HQ
            sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1

            sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
            sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950

            # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            # Q_IDX and Q_NUM_BLKS are always contiguous.
            q_indices = Q_IDX + sparse_q_idx_offset
            q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
            sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)

            offs_m1 = q_start + tl.arange(0, BLOCK_M1)
            dk, dv = bwd_dkdv_inner(
                arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7,
                Q1, DO1, DELTA1, LSE1,
                dk, dv, k, v,
                off_zq, off_hq1, offs_n1, offs_m1,
                stride_qm, stride_qd, stride_dom, stride_dod,
                q_indices, sparse_q_num_blocks,
                MATMUL_PRECISION,
                IS_FULL_BLOCKS=False,
            )

            if HAS_FULL_BLOCKS:
                # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
                # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
                q_indices = FULL_Q_IDX + sparse_q_idx_offset
                q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
                sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)

                offs_m1 = q_start + tl.arange(0, BLOCK_M1)
                dk, dv = bwd_dkdv_inner(
                    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7,
                    Q1, DO1, DELTA1, LSE1,
                    dk, dv, k, v,
                    off_zq, off_hq1, offs_n1, offs_m1,
                    stride_qm, stride_qd, stride_dom, stride_dod,
                    q_indices, sparse_q_num_blocks,
                    MATMUL_PRECISION,
                    IS_FULL_BLOCKS=True,
                )

        # Write back dV and dK.
        dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd

        index_n = offs_n1[:, None]
        index_k = offs_k[None, :]
        index_v = offs_v[None, :]

        if IS_DIVISIBLE and SAFE_HEAD_DIM:
            tl.store(dv_ptrs, dv)
        else:
            tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))

        dk *= SM_SCALE

        if SAFE_HEAD_DIM:
            mask = index_n < KV_LEN
        else:
            mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)

        # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
        # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
        tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
        xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1
        tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask)

@triton.jit
def bwd_dq_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7,
    K, V,  # pointers
    dq, q, do, Di, lse,
    off_z, off_hq, offs_m2, offs_n2,
    stride_kn, stride_kd, stride_vn, stride_vd,
    kv_indices, sparse_kv_num_blocks,
    MATMUL_PRECISION,
    IS_FULL_BLOCKS,
):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
    RCP_LN2: tl.constexpr = 1.44269504
    Q_LEN = ks0
    KV_LEN = ks1

    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)

    kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
    vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
    # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)

    hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))

    for start_n in range(0, hi):
        dq = bwd_dq_block_mn(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7,
            dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
            off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
            stride_kn, stride_kd, stride_vn, stride_vd,
            kv_indices, sparse_kv_num_blocks,
            MATMUL_PRECISION, RCP_LN2,
            IS_FULL_BLOCKS,
        )

        # Increment pointers.
        offset = get_offset_for_next_block(
            start_n, kv_indices, sparse_kv_num_blocks,
            SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
        )

        kT_ptrs += offset * stride_kn
        vT_ptrs += offset * stride_vn

        offs_n2 += offset

    return dq


@triton.jit
def bwd_dq_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7,
    dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
    off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
    stride_kn, stride_kd, stride_vn, stride_vd,
    kv_indices, sparse_kv_num_blocks,
    MATMUL_PRECISION, RCP_LN2,
    IS_FULL_BLOCKS,
):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    # NB reversed order since K is transposed
    kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
    qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
    if not PRESCALE_QK:
        qk *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    pre_mod_scores = qk
    n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
    # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim
    # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary
    m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)

    tmp0 = (qk)
    post_mod_scores = tmp0

    if not IS_DIVISIBLE:
        post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        tmp1 = (m)
        tmp2 = tl.full([1], 0, tl.int32)
        tmp3 = tmp1 < tmp2
        tmp4 = (n)
        tmp5 = tmp4 <= tmp1
        tmp6 = tmp3 & tmp5
        tmp7 = tmp1 >= tmp2
        tmp8 = tmp4 < tmp2
        tmp9 = tmp7 & tmp8
        tmp10 = tmp8 == 0
        tmp11 = tmp7 & tmp10
        tmp12 = tmp1 - tmp2
        tmp13 = tl.full([1], 16, tl.int32)
        tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
        tmp15 = tmp4 - tmp2
        tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
        tmp17 = tmp14 == tmp16
        tmp18 = tmp11 & tmp17
        tmp19 = tmp9 | tmp18
        tmp20 = tmp6 | tmp19
        mask_mod_output = tmp20

        # apply mask for partial masked block
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if not PRESCALE_QK:
        post_mod_scores *= RCP_LN2
    p = tl.math.exp2(post_mod_scores - lse)
    # Compute dP and dS.
    # NB reversed order since V is transposed
    vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)

    dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
    ds = p * (dp - Di[:, None])
    # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
    tmp21 = (ds)
    grad_scores = tmp21

    if not IS_DIVISIBLE:
        grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)

    # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
    if WRITE_DQ:
        scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    ds = grad_scores

    if not IS_FULL_BLOCKS:
        # (grads) apply mask for partially unmasked block
        ds = tl.where(mask_mod_output, ds, 0.0)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    ds = ds.to(MATMUL_PRECISION)
    # Compute dQ.
    dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)

    return dq


@triton.jit
def bwd_dkdv_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7,
    Q, DO, DELTA, LSE, # pointers
    dk, dv, k, v,
    off_z, off_hq, offs_n1, offs_m1,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_indices, sparse_q_num_blocks,
    MATMUL_PRECISION,
    IS_FULL_BLOCKS,
):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
    RCP_LN2: tl.constexpr = 1.44269504
    Q_LEN = ks0
    KV_LEN = ks1

    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)

    qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
    do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)

    # The minimum is needed to handle the case where we run with a super large
    # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
    hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))

    for start_m in range(0, hi):
        dk, dv = bwd_dkdv_block_mn(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7,
            dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
            off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
            stride_qm, stride_qd, stride_dom, stride_dod,
            q_indices, sparse_q_num_blocks,
            MATMUL_PRECISION, RCP_LN2,
            IS_FULL_BLOCKS,
        )
        # Increment pointers.
        offset = get_offset_for_next_block(
            start_m, q_indices, sparse_q_num_blocks,
            SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
        )

        qT_ptrs += offset * stride_qm
        do_ptrs += offset * stride_dom
        offs_m1 += offset

    return dk, dv


@triton.jit
def bwd_dkdv_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7,
    dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
    off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_indices, sparse_q_num_blocks,
    MATMUL_PRECISION, RCP_LN2,
    IS_FULL_BLOCKS,
):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    # NB reversed order since Q is transposed
    qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
    # Load LSE before computing qk to reduce pipeline stall.
    if IS_DIVISIBLE:
        lse = tl.load(LSE + offs_m1)
    else:
        lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
    lse = tl.where(lse == -float("inf"), 0.0, lse)
    qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
    if not PRESCALE_QK:
        qkT *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
    # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim
    # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary
    n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)

    pre_mod_scores = qkT
    tmp22 = (qkT)
    post_mod_scores = tmp22

    if not IS_DIVISIBLE:
        post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        tmp23 = (m)
        tmp24 = tl.full([1], 0, tl.int32)
        tmp25 = tmp23 < tmp24
        tmp26 = (n)
        tmp27 = tmp26 <= tmp23
        tmp28 = tmp25 & tmp27
        tmp29 = tmp23 >= tmp24
        tmp30 = tmp26 < tmp24
        tmp31 = tmp29 & tmp30
        tmp32 = tmp30 == 0
        tmp33 = tmp29 & tmp32
        tmp34 = tmp23 - tmp24
        tmp35 = tl.full([1], 16, tl.int32)
        tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35)
        tmp37 = tmp26 - tmp24
        tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35)
        tmp39 = tmp36 == tmp38
        tmp40 = tmp33 & tmp39
        tmp41 = tmp31 | tmp40
        tmp42 = tmp28 | tmp41
        mask_mod_output = tmp42

        # (grads) apply mask for fully masked block
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if not PRESCALE_QK:
        post_mod_scores *= RCP_LN2
    pT = tl.math.exp2(post_mod_scores - lse[None, :])
    do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
    # Compute dV.
    ppT = pT
    dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
    if IS_DIVISIBLE:
        Di = tl.load(DELTA + offs_m1)
    else:
        Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
    # Compute dP and dS.
    dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
    dsT = pT * (dpT - Di[None, :])
    # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
    tmp43 = (dsT)
    grad_scores = tmp43

    if not IS_DIVISIBLE:
        grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)

    # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
    if not WRITE_DQ:
        idx_b = off_z
        idx_h = off_hq
        idx_m = m
        idx_n = n
        scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    dsT = grad_scores
    if not IS_FULL_BLOCKS:
        # (grads) apply mask for partially unmasked block
        dsT = tl.where(mask_mod_output, dsT, 0.0)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)

    return dk, dv

# Utility triton funcs
@triton.jit
def get_offset_for_next_block(
    loop_iter, col_indices, total_blocks,
    SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr
):
    if BLOCKS_ARE_CONTIGUOUS:
        return BLOCK
    cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
    cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
    next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
    needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
    jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
    offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
    return offset

@triton.jit
def get_bounded_indices(indices, max_len=None):
    return indices % max_len if max_len is not None else indices

@triton.jit
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
    if IS_DIVISIBLE and SAFE_HEAD_DIM:
        return tl.load(block_ptr)
    elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
    elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
    else:
        return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")

@triton.jit
def load_checked_2d(
    ptr,
    offs_m,
    offs_n,
    stride_m,
    stride_n,
    IS_DIVISIBLE_M: tl.constexpr,
    IS_DIVISIBLE_N: tl.constexpr,
    M_LEN: tl.constexpr,
    N_LEN: tl.constexpr,
):
    # Calculate final pointer if strides are provided
    if stride_m is not None and stride_n is not None:
        ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n

    # Handle all masking cases
    if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
    elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
    elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
    else:  # Both divisible
        return tl.load(ptr)
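The file above is the Inductor-generated backward template for flex attention: one launch grid splits program ids between dQ blocks and dK/dV blocks, bwd_dq_inner / bwd_dkdv_inner walk the block-sparse mask via get_offset_for_next_block, and the constexpr header (GQA_SHARED_HEADS = 4, head dim 128, SM_SCALE = 1/sqrt(128) ≈ 0.0884) matches a 32-query-head / 8-KV-head model. A hedged sketch of the kind of call that makes torch.compile emit such a kernel (shapes follow the asserts earlier in this diff; this is not the repo's training code):

import torch
from torch.nn.attention.flex_attention import flex_attention

# Compile flex attention with GQA and run backward; Inductor lowers the
# backward pass to a templated kernel like triton_tem_fused_mul_1 above.
flex = torch.compile(flex_attention)
q = torch.randn(1, 32, 128, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(1, 8, 128, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(1, 8, 128, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
out = flex(q, k, v, enable_gqa=True)
out.sum().backward()   # triggers the dQ and dK/dV program-id split seen above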
progress/SpecForge/cache/compiled_kernels/2v/c2vob47d7sxpitzmofyr55f5hvxsitxjhpyv5hdiqcdjgbwmxk76.py ADDED
@@ -0,0 +1,799 @@

import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties

@triton_heuristics.template(

    num_stages=3,
    num_warps=8,
    triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]},
    inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},

)
@triton.jit
def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32
    Q = arg_Q
    K = arg_K
    V = arg_V
    LSE = arg_LSE
    DELTA = arg_DELTA
    DO = arg_DO
    DQ = arg_DQ
    DV = arg_DV
    KV_NUM_BLKS = arg_KV_NUM_BLKS
    KV_IDX = arg_KV_IDX
    Q_NUM_BLKS = arg_Q_NUM_BLKS
    Q_IDX = arg_Q_IDX
    FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
    FULL_KV_IDX = arg_FULL_KV_IDX
    FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
    FULL_Q_IDX = arg_FULL_Q_IDX

    # Sub notation for this kernel:
    #
    # Q: Query, K: Key, V: Value
    # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
    # DELTA: Precomputed sum(OUT*DO, axis=-1)
    # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
    # DK: Derivative of Key, is written to via the store_output call due to some limitations with
    # inductor codegen
    # M: Number of queries, N: Number of keys/values
    # QK_HEAD_DIM: The dimension of the query and key embeddings
    # V_HEAD_DIM: The dimension of the value embeddings
    # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
    # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
    # (Modifiable) Performance tuning options
    # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
    # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
    # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
    # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
    #
    # The following FULL_* and PARTIAL_* are defined in the block sparse mask grid, rather than the thread block grid.
    # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
    # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
    # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
    # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
    # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
    # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
    # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
    # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.

    # The below are kernel options that can be applied for certain score_mods,
    # or involve a numerics vs. perf tradeoff
    # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
    # about 20% more numerical error, but slightly faster.

    # Define strides of inputs
    stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1
    stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1
    stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1
    stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1

    stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1
    stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1

    ZQ = 1
    HQ = 32
    HKV = 8
    Q_LEN = ks0
    ZKV = 1
    KV_LEN = ks1

    MATMUL_PRECISION = Q.dtype.element_ty

    pid = tl.program_id(0).to(INDEX_DTYPE)
    NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
    NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)

    off_zq = tl.program_id(1).to(INDEX_DTYPE)  # q batch idx
    off_hkv = tl.program_id(2).to(INDEX_DTYPE)  # kv head idx
    off_zkv = off_zq % ZKV  # kv batch idx

    SPARSE_Z = 1
    SPARSE_HQ = 1

    sparse_idx_z = off_zq % SPARSE_Z

    k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
    v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
    # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
    # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
    dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)

    # offset K, V, DV pointers for batch/kv-head
    K += k_adj
    V += v_adj
    DV += dv_adj

    RCP_LN2 = 1.44269504
    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)

    if pid >= NUM_KV_BLOCKS:
        off_pid = pid - NUM_KV_BLOCKS
        # THIS BLOCK DOES DQ
        SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
        SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
        off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
        start_m2_block = off_pid % NUM_Q_BLOCKS
        off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
        stride_kv_num_blks_h = 1
        stride_kv_idx_h = 1
        stride_kv_idx_m = 1

        sparse_idx_hq2 = off_hq2 % SPARSE_HQ
        sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2

        sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
        sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m  # noqa: B950

        # Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query heads.
        q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
        do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
        dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
        off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)

        Q2 = Q + q_adj2
        DO2 = DO + do_adj2
        # TODO: This does not work if DQ is not the same layout as Q (for example,
        # if Q is broadcasted)
        DQ2 = DQ + dq_adj2
        LSE2 = LSE + off_chz2
        DELTA2 = DELTA + off_chz2

        # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
        dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)

        start_m2 = start_m2_block * BLOCK_M2
        offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)

        # load Q and do: they stay in SRAM throughout the inner loop.
        q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
        do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)

        if PRESCALE_QK:
            q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)

        if IS_DIVISIBLE:
            Di = tl.load(DELTA2 + offs_m2)
            lse = tl.load(LSE2 + offs_m2)
        else:
            Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
            lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
        lse = tl.where(lse == -float("inf"), 0.0, lse)
        lse = lse[:, None]

        # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # KV_IDX and KV_NUM_BLKS are always contiguous.
        kv_indices = KV_IDX + sparse_kv_idx_offset
        kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE  # first kv block we're loading
        sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)

        offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
        dq = bwd_dq_inner(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
            K, V,
            dq, q, do, Di, lse,
            off_zq, off_hq2, offs_m2, offs_n2,
            stride_kn, stride_kd, stride_vn, stride_vd,
            kv_indices, sparse_kv_num_blocks,
            MATMUL_PRECISION,
            IS_FULL_BLOCKS=False,
        )

        if HAS_FULL_BLOCKS:
            # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
            kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
            kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE  # first kv block we're loading
            sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)

            offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
            dq = bwd_dq_inner(
                arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
                K, V,
                dq, q, do, Di, lse,
                off_zq, off_hq2, offs_m2, offs_n2,
                stride_kn, stride_kd, stride_vn, stride_vd,
                kv_indices, sparse_kv_num_blocks,
                MATMUL_PRECISION,
                IS_FULL_BLOCKS=True,
            )

        # Write back dQ.
        dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
        dq *= SM_SCALE
        if IS_DIVISIBLE and SAFE_HEAD_DIM:
            tl.store(dq_ptrs, dq)
        else:
            tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
    else:
        # THIS BLOCK DOES DK & DV
        SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
        SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)

        pid_mask = pid // SPARSE_KV_MULTIPLE

        stride_q_num_blks_h = 1
        stride_q_idx_h = 1
        stride_q_idx_n = 1

        dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
        dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)

        start_n1 = pid * BLOCK_N1
        offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)

        # load K and V: they stay in SRAM throughout the inner loop.
        k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
        v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)

        if PRESCALE_QK:
            k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)

        for off_g in range(0, GQA_SHARED_HEADS):
            off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g

            # Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query heads.
            q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
            do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
            dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
            off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)

            Q1 = Q + q_adj1
            DO1 = DO + do_adj1
            # TODO: This does not work if DQ is not the same layout as Q (for example,
            # if Q is broadcasted)
            LSE1 = LSE + off_chz1
            DELTA1 = DELTA + off_chz1

            sparse_idx_hq1 = off_hq1 % SPARSE_HQ
            sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1

            sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
            sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n  # noqa: B950

            # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            # Q_IDX and Q_NUM_BLKS are always contiguous.
            q_indices = Q_IDX + sparse_q_idx_offset
            q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE  # first q block we're loading
            sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)

            offs_m1 = q_start + tl.arange(0, BLOCK_M1)
            dk, dv = bwd_dkdv_inner(
                arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
                Q1, DO1, DELTA1, LSE1,
                dk, dv, k, v,
                off_zq, off_hq1, offs_n1, offs_m1,
                stride_qm, stride_qd, stride_dom, stride_dod,
                q_indices, sparse_q_num_blocks,
                MATMUL_PRECISION,
                IS_FULL_BLOCKS=False,
            )

            if HAS_FULL_BLOCKS:
                # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
                # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
                q_indices = FULL_Q_IDX + sparse_q_idx_offset
                q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE  # first q block we're loading
                sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)

                offs_m1 = q_start + tl.arange(0, BLOCK_M1)
                dk, dv = bwd_dkdv_inner(
                    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
                    Q1, DO1, DELTA1, LSE1,
                    dk, dv, k, v,
                    off_zq, off_hq1, offs_n1, offs_m1,
                    stride_qm, stride_qd, stride_dom, stride_dod,
                    q_indices, sparse_q_num_blocks,
                    MATMUL_PRECISION,
                    IS_FULL_BLOCKS=True,
                )

        # Write back dV and dK.
        dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd

        index_n = offs_n1[:, None]
        index_k = offs_k[None, :]
        index_v = offs_v[None, :]

        if IS_DIVISIBLE and SAFE_HEAD_DIM:
            tl.store(dv_ptrs, dv)
        else:
            tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))

        dk *= SM_SCALE

        if SAFE_HEAD_DIM:
            mask = index_n < KV_LEN
        else:
            mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)

        # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
        # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
        tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
        xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1
        tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask)

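# The `off_g` loop above implements the GQA grouping: with HQ = 32 query heads
# and HKV = 8 kv heads, GQA_SHARED_HEADS = 4 query heads all accumulate their
# dK/dV contributions into one kv head. A plain-Python restatement of that
# mapping (the helper name is illustrative, not part of the generated code):
def _gqa_query_heads_for_kv_head(off_hkv, gqa_shared_heads=4):
    return [off_hkv * gqa_shared_heads + g for g in range(gqa_shared_heads)]

# _gqa_query_heads_for_kv_head(2) -> [8, 9, 10, 11]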
@triton.jit
def bwd_dq_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
    K, V,  # pointers
    dq, q, do, Di, lse,
    off_z, off_hq, offs_m2, offs_n2,
    stride_kn, stride_kd, stride_vn, stride_vd,
    kv_indices, sparse_kv_num_blocks,
    MATMUL_PRECISION,
    IS_FULL_BLOCKS,
):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
    RCP_LN2: tl.constexpr = 1.44269504
    Q_LEN = ks0
    KV_LEN = ks1

    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)

    kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
    vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
    # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)

    hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))

    for start_n in range(0, hi):
        dq = bwd_dq_block_mn(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
            dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
            off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
            stride_kn, stride_kd, stride_vn, stride_vd,
            kv_indices, sparse_kv_num_blocks,
            MATMUL_PRECISION, RCP_LN2,
            IS_FULL_BLOCKS,
        )

        # Increment pointers.
        offset = get_offset_for_next_block(
            start_n, kv_indices, sparse_kv_num_blocks,
            SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
        )

        kT_ptrs += offset * stride_kn
        vT_ptrs += offset * stride_vn

        offs_n2 += offset

    return dq

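# The RCP_LN2 factor used in these kernels is log2(e): scores are rescaled so
# the softmax can use the hardware-friendly exp2 instead of exp, relying on
# the identity exp(x) == exp2(x * log2(e)). A quick numerical check:
import math
for _x in (-3.0, 0.0, 0.5, 7.25):
    assert abs(math.exp(_x) - 2.0 ** (_x * 1.4426950408889634)) < 1e-9 * math.exp(_x)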
@triton.jit
def bwd_dq_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
    dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
    off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
    stride_kn, stride_kd, stride_vn, stride_vd,
    kv_indices, sparse_kv_num_blocks,
    MATMUL_PRECISION, RCP_LN2,
    IS_FULL_BLOCKS,
):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    # NB reversed order since K is transposed
    kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
    qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
    if not PRESCALE_QK:
        qk *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    pre_mod_scores = qk
    n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
    # The boundary check is done for the outer loop, but here it's possible, since we're
    # iterating across the N dim, that the M reads go out of bounds for the PIDs spanning
    # the Q_LEN boundary
    m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)

    tmp0 = (qk)
    post_mod_scores = tmp0

    if not IS_DIVISIBLE:
        post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        tmp1 = (m)
        tmp2 = tl.full([1], 0, tl.int32)
        tmp3 = tmp1 < tmp2
        tmp4 = (n)
        tmp5 = tmp4 <= tmp1
        tmp6 = tmp3 & tmp5
        tmp7 = tmp1 >= tmp2
        tmp8 = tmp4 < tmp2
        tmp9 = tmp7 & tmp8
        tmp10 = tmp8 == 0
        tmp11 = tmp7 & tmp10
        tmp12 = tmp1 - tmp2
        tmp13 = tl.full([1], 16, tl.int32)
        tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
        tmp15 = tmp4 - tmp2
        tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
        tmp17 = tmp14 == tmp16
        tmp18 = tmp11 & tmp17
        tmp19 = tmp9 | tmp18
        tmp20 = tmp6 | tmp19
        mask_mod_output = tmp20

        # apply mask for partial masked block
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if not PRESCALE_QK:
        post_mod_scores *= RCP_LN2
    p = tl.math.exp2(post_mod_scores - lse)
    # Compute dP and dS.
    # NB reversed order since V is transposed
    vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)

    dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
    ds = p * (dp - Di[:, None])
    # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
    tmp21 = (ds)
    grad_scores = tmp21

    if not IS_DIVISIBLE:
        grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)

    # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
    if WRITE_DQ:
        scatter_mask = (offs_m2[:, None] < Q_LEN) & (offs_n2[None, :] < KV_LEN)

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    ds = grad_scores

    if not IS_FULL_BLOCKS:
        # (grads) apply mask for partially unmasked block
        ds = tl.where(mask_mod_output, ds, 0.0)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    ds = ds.to(MATMUL_PRECISION)
    # Compute dQ.
    dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)

    return dq

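# The tmp1..tmp20 chain above is inductor's straight-line lowering of the
# mask_mod closure (the tl.where dance emulates floor division for signed
# ints). Read back into plain Python, with the 16 being the block width baked
# into this particular compile, the predicate appears to be:
def _mask_mod_ref(m, n):
    if m < 0:
        return n <= m
    # m >= 0: the whole negative (prefix) range is visible, plus the
    # 16-wide diagonal block containing m
    return n < 0 or (m // 16) == (n // 16)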
@triton.jit
def bwd_dkdv_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
    Q, DO, DELTA, LSE,  # pointers
    dk, dv, k, v,
    off_z, off_hq, offs_n1, offs_m1,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_indices, sparse_q_num_blocks,
    MATMUL_PRECISION,
    IS_FULL_BLOCKS,
):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
    RCP_LN2: tl.constexpr = 1.44269504
    Q_LEN = ks0
    KV_LEN = ks1

    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)

    qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
    do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)

    # The minimum is needed to handle the case where we run with a super large
    # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
    hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))

    for start_m in range(0, hi):
        dk, dv = bwd_dkdv_block_mn(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
            dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
            off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
            stride_qm, stride_qd, stride_dom, stride_dod,
            q_indices, sparse_q_num_blocks,
            MATMUL_PRECISION, RCP_LN2,
            IS_FULL_BLOCKS,
        )
        # Increment pointers.
        offset = get_offset_for_next_block(
            start_m, q_indices, sparse_q_num_blocks,
            SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
        )

        qT_ptrs += offset * stride_qm
        do_ptrs += offset * stride_dom
        offs_m1 += offset

    return dk, dv

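# The update dsT = pT * (dpT - Di) in the next function is the standard
# softmax backward, where Di = DELTA = sum(OUT * DO, axis=-1), which equals
# sum(p * dp) per row. A NumPy check on random toy data (shapes illustrative):
import numpy as np
_rng = np.random.default_rng(0)
_s = _rng.standard_normal(8)
_dp = _rng.standard_normal(8)       # upstream grad w.r.t. the softmax output
_p = np.exp(_s - _s.max()); _p /= _p.sum()
_Di = (_p * _dp).sum()              # plays the role of one DELTA row entry
_ds = _p * (_dp - _Di)
_J = np.diag(_p) - np.outer(_p, _p)  # softmax Jacobian for reference
assert np.allclose(_ds, _J @ _dp)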
@triton.jit
def bwd_dkdv_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
    dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
    off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_indices, sparse_q_num_blocks,
    MATMUL_PRECISION, RCP_LN2,
    IS_FULL_BLOCKS,
):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    # NB reversed order since Q is transposed
    qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
    # Load LSE before computing qk to reduce pipeline stall.
    if IS_DIVISIBLE:
        lse = tl.load(LSE + offs_m1)
    else:
        lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
    lse = tl.where(lse == -float("inf"), 0.0, lse)
    qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
    if not PRESCALE_QK:
        qkT *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
    # The boundary check is done for the outer loop, but here it's possible, since we're
    # iterating across the M dim, that the n reads go out of bounds for the PIDs spanning
    # the KV_LEN boundary
    n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)

    pre_mod_scores = qkT
    tmp22 = (qkT)
    post_mod_scores = tmp22

    if not IS_DIVISIBLE:
        post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        tmp23 = (m)
        tmp24 = tl.full([1], 0, tl.int32)
        tmp25 = tmp23 < tmp24
        tmp26 = (n)
        tmp27 = tmp26 <= tmp23
        tmp28 = tmp25 & tmp27
        tmp29 = tmp23 >= tmp24
        tmp30 = tmp26 < tmp24
        tmp31 = tmp29 & tmp30
        tmp32 = tmp30 == 0
        tmp33 = tmp29 & tmp32
        tmp34 = tmp23 - tmp24
        tmp35 = tl.full([1], 16, tl.int32)
        tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35)
        tmp37 = tmp26 - tmp24
        tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35)
        tmp39 = tmp36 == tmp38
        tmp40 = tmp33 & tmp39
        tmp41 = tmp31 | tmp40
        tmp42 = tmp28 | tmp41
        mask_mod_output = tmp42

        # (grads) apply mask for fully masked block
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if not PRESCALE_QK:
        post_mod_scores *= RCP_LN2
    pT = tl.math.exp2(post_mod_scores - lse[None, :])
    do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
    # Compute dV.
    ppT = pT
    dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
    if IS_DIVISIBLE:
        Di = tl.load(DELTA + offs_m1)
    else:
        Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
    # Compute dP and dS.
    dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
    dsT = pT * (dpT - Di[None, :])
    # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
    tmp43 = (dsT)
    grad_scores = tmp43

    if not IS_DIVISIBLE:
        grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)

    # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
    if not WRITE_DQ:
        idx_b = off_z
        idx_h = off_hq
        idx_m = m
        idx_n = n
        scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    dsT = grad_scores
    if not IS_FULL_BLOCKS:
        # (grads) apply mask for partially unmasked block
        dsT = tl.where(mask_mod_output, dsT, 0.0)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)

    return dk, dv

# Utility triton funcs
@triton.jit
def get_offset_for_next_block(
    loop_iter, col_indices, total_blocks,
    SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr
):
    if BLOCKS_ARE_CONTIGUOUS:
        return BLOCK
    cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
    cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
    next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
    needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
    jump_to_block = (next_block - cur_block) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
    offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
    return offset

@triton.jit
def get_bounded_indices(indices, max_len=None):
    return indices % max_len if max_len is not None else indices

@triton.jit
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
    if IS_DIVISIBLE and SAFE_HEAD_DIM:
        return tl.load(block_ptr)
    elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
    elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
    else:
        return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")

@triton.jit
def load_checked_2d(
    ptr,
    offs_m,
    offs_n,
    stride_m,
    stride_n,
    IS_DIVISIBLE_M: tl.constexpr,
    IS_DIVISIBLE_N: tl.constexpr,
    M_LEN: tl.constexpr,
    N_LEN: tl.constexpr,
):
    # Calculate final pointer if strides are provided
    if stride_m is not None and stride_n is not None:
        ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n

    # Handle all masking cases
    if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
    elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
    elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
    else:  # Both divisible
        return tl.load(ptr)
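The SM_SCALE constant baked into these kernels, 0.08838834764831845, is just 1/sqrt(128), the usual 1/sqrt(d_k) attention scaling for the 128-wide heads used here. A quick check:

import math
assert abs(0.08838834764831845 - 1 / math.sqrt(128)) < 1e-15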
progress/SpecForge/cache/compiled_kernels/2y/c2yhndikcsebqfmbw7l44gmcdoyw7ogaqt7quyeygz3mp5w6u6ke.py
ADDED
@@ -0,0 +1,715 @@
# AOT ID: ['4_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream

aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
assert_alignment = torch._C._dynamo.guards.assert_alignment
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p

# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/6n/c6n4rf57opno6rcuedu4jk4etcok4ti2tlaztx2ht3z5eydc3vae.py
# Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
# Source node to ATen node mapping:
#   flex_attention => flex_attention
# Graph fragment:
#   %arg1_1 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:1" = PlaceHolder[target=arg1_1]
#   %arg3_1 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:1" = PlaceHolder[target=arg3_1]
#   %arg5_1 : Tensor "bf16[1, 8, s43, 128][1024*s43, 128, 1024, 1]cuda:1" = PlaceHolder[target=arg5_1]
#   %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:1" = PlaceHolder[target=getitem_1]
#   %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:1" = PlaceHolder[target=buf1]
#   %arg9_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:1" = PlaceHolder[target=arg9_1]
#   %arg6_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:1" = PlaceHolder[target=arg6_1]
#   %arg10_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:1" = PlaceHolder[target=arg10_1]
#   %arg11_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:1" = PlaceHolder[target=arg11_1]
#   %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
#   return %getitem
triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties

@triton_heuristics.template(

    num_stages=3,
    num_warps=8,
    triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]},
    inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},

)
@triton.jit
def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'ieee'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    USE_TMA : tl.constexpr = False
    BLOCK_M : tl.constexpr = 128
    BLOCK_N : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32
    Q = arg_Q
    K = arg_K
    V = arg_V
    LSE = arg_LSE
    MAX = arg_MAX
    KV_NUM_BLKS = arg_KV_NUM_BLKS
    KV_IDX = arg_KV_IDX
    FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
    FULL_KV_IDX = arg_FULL_KV_IDX

    # Sub notation for this kernel:
    #
    # Q: Query, K: Key, V: Value
    # M: Number of queries, N: Number of keys/values, D: Model dimension
    # QK_HEAD_DIM: The dimension of the query and key embeddings
    # V_HEAD_DIM: The dimension of the value embeddings
    # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
    # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
    #
    # The following FULL_* and PARTIAL_* are defined in the block sparse mask grid, rather than the thread block grid.
    # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
    # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
    # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
    # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
    #
    # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
    #
    # (Modifiable) Performance tuning options
    # BLOCK_M: The thread block size across the seqlen dim of Q.
    # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.

    # The below are kernel options that can be applied for certain score_mods,
    # or involve a numerics vs. perf tradeoff
    # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
    # about 20% more numerical error, but slightly faster.
    # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
    # is not masked out? If so, we can skip an extra safety check
    # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
    # contiguous? If so, we don't need to do an indirect jump for every block

    tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
    tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
|
| 135 |
+
|
| 136 |
+
# Define strides of inputs
|
| 137 |
+
stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1
|
| 138 |
+
stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1
|
| 139 |
+
stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1
|
| 140 |
+
|
| 141 |
+
ZQ = 1
|
| 142 |
+
HQ = 32
|
| 143 |
+
Q_LEN = ks0
|
| 144 |
+
ZKV = 1
|
| 145 |
+
KV_LEN = ks1
|
| 146 |
+
|
| 147 |
+
MATMUL_PRECISION = Q.dtype.element_ty
|
| 148 |
+
|
| 149 |
+
q_start = tl.program_id(0).to(INDEX_DTYPE)
|
| 150 |
+
off_zq = tl.program_id(1).to(INDEX_DTYPE)
|
| 151 |
+
off_hq = tl.program_id(2).to(INDEX_DTYPE)
|
| 152 |
+
|
| 153 |
+
# We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
|
| 154 |
+
# b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
|
| 155 |
+
off_zkv = off_zq % ZKV
|
| 156 |
+
off_hkv = off_hq // GQA_SHARED_HEADS
|
| 157 |
+
off_g = off_hq % GQA_SHARED_HEADS
|
| 158 |
+
|
| 159 |
+
q_offset = off_zq * stride_qz + off_hq * stride_qh
|
| 160 |
+
k_offset = off_zkv * stride_kz + off_hkv * stride_kh
|
| 161 |
+
v_offset = off_zkv * stride_vz + off_hkv * stride_vh
|
| 162 |
+
|
| 163 |
+
Q = Q + q_offset
|
| 164 |
+
K = K + k_offset
|
| 165 |
+
V = V + v_offset
|
| 166 |
+
|
| 167 |
+
# Setting up the TMA descriptors for Q, K, V
|
| 168 |
+
desc_q = None
|
| 169 |
+
desc_k = None
|
| 170 |
+
desc_v = None
|
| 171 |
+
|
| 172 |
+
SPARSE_Z = 1
|
| 173 |
+
SPARSE_HQ = 1
|
| 174 |
+
|
| 175 |
+
sparse_idx_z = off_zq % SPARSE_Z
|
| 176 |
+
sparse_idx_hq = off_hq % SPARSE_HQ
|
| 177 |
+
|
| 178 |
+
SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
|
| 179 |
+
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
|
| 180 |
+
|
| 181 |
+
stride_kv_num_blks_h = 1
|
| 182 |
+
stride_kv_idx_h = 1
|
| 183 |
+
stride_kv_idx_m = 1
|
| 184 |
+
|
| 185 |
+
# initialize pointer to m and l
|
| 186 |
+
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
| 187 |
+
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
| 188 |
+
acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
| 189 |
+
|
| 190 |
+
offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 191 |
+
|
| 192 |
+
# KV_IDX and KV_NUM_BLKS are always contiguous.
|
| 193 |
+
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
|
| 194 |
+
sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
|
| 195 |
+
sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
|
| 196 |
+
offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 197 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 198 |
+
q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
|
| 199 |
+
|
| 200 |
+
# ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 201 |
+
# We don't know anything "special" about these blocks, so we need to apply
|
| 202 |
+
# both score_mod and mask_mod to it
|
| 203 |
+
kv_indices = KV_IDX + sparse_kv_idx_offset
|
| 204 |
+
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
|
| 205 |
+
kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
|
| 206 |
+
block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
# K and V pointers will be passed directly to forward_inner
|
| 210 |
+
|
| 211 |
+
offs_n = kv_start + tl.arange(0, BLOCK_N)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
acc, l_i, m_i = forward_inner(
|
| 215 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
|
| 216 |
+
q, K, V,
|
| 217 |
+
desc_k, desc_v, Q_LEN, KV_LEN,
|
| 218 |
+
acc, l_i, m_i,
|
| 219 |
+
off_zq, off_hq, offs_m[:, None], offs_n[None, :],
|
| 220 |
+
kv_start,
|
| 221 |
+
kv_indices, kv_num_blocks,
|
| 222 |
+
0, block_n_end,
|
| 223 |
+
MATMUL_PRECISION,
|
| 224 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 225 |
+
IS_FULL_BLOCKS=False,
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
# ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 229 |
+
# We know these blocks are guaranteed to be "full", so we don't need to
|
| 230 |
+
# apply mask_mod to them - only score_mod
|
| 231 |
+
if HAS_FULL_BLOCKS:
|
| 232 |
+
# FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
|
| 233 |
+
kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
|
| 234 |
+
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
|
| 235 |
+
kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
|
| 236 |
+
block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
|
| 237 |
+
# K and V pointers will be passed directly to forward_inner
|
| 238 |
+
offs_n = kv_start + tl.arange(0, BLOCK_N)
|
| 239 |
+
|
| 240 |
+
acc, l_i, m_i = forward_inner(
|
| 241 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
|
| 242 |
+
q, K, V,
|
| 243 |
+
desc_k, desc_v, Q_LEN, KV_LEN,
|
| 244 |
+
acc, l_i, m_i,
|
| 245 |
+
off_zq, off_hq, offs_m[:, None], offs_n[None, :],
|
| 246 |
+
kv_start,
|
| 247 |
+
kv_indices, kv_num_blocks,
|
| 248 |
+
0, block_n_end,
|
| 249 |
+
MATMUL_PRECISION,
|
| 250 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 251 |
+
IS_FULL_BLOCKS=True,
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
# [Note] Handle fully masked out rows:
|
| 256 |
+
# Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
|
| 257 |
+
# We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
|
| 258 |
+
l_i = tl.where(l_i == 0.0, 1, l_i)
|
| 259 |
+
|
| 260 |
+
acc = acc / l_i[:, None]
|
| 261 |
+
idx_zq = tl.program_id(1).to(INDEX_DTYPE)
|
| 262 |
+
idx_hq = tl.program_id(2).to(INDEX_DTYPE)
|
| 263 |
+
idx_m = offs_m[:, None].to(INDEX_DTYPE)
|
| 264 |
+
idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
|
| 265 |
+
|
| 266 |
+
mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
|
| 267 |
+
|
| 268 |
+
tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
|
| 269 |
+
xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0
|
| 270 |
+
tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask)
|
| 271 |
+
|
| 272 |
+
if OUTPUT_LOGSUMEXP:
|
| 273 |
+
off_hz = off_zq * HQ + off_hq
|
| 274 |
+
l_ptrs = LSE + off_hz * Q_LEN + offs_m
|
| 275 |
+
lse = m_i + tl.math.log2(l_i)
|
| 276 |
+
if IS_DIVISIBLE:
|
| 277 |
+
tl.store(l_ptrs, lse)
|
| 278 |
+
else:
|
| 279 |
+
tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
|
| 280 |
+
|
| 281 |
+
if OUTPUT_MAX:
|
| 282 |
+
off_hz = off_zq * HQ + off_hq
|
| 283 |
+
max_ptrs = MAX + off_hz * Q_LEN + offs_m
|
| 284 |
+
if IS_DIVISIBLE:
|
| 285 |
+
tl.store(max_ptrs, m_i)
|
| 286 |
+
else:
|
| 287 |
+
tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
# Utility triton funcs
|
| 291 |
+
@triton.jit
|
| 292 |
+
def get_offset_for_next_block(
|
| 293 |
+
loop_iter, col_indices, total_blocks,
|
| 294 |
+
SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
|
| 295 |
+
BLOCKS_ARE_CONTIGUOUS: tl.constexpr
|
| 296 |
+
):
|
| 297 |
+
if BLOCKS_ARE_CONTIGUOUS:
|
| 298 |
+
return BLOCK
|
| 299 |
+
cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
|
| 300 |
+
cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
|
| 301 |
+
next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
|
| 302 |
+
needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
|
| 303 |
+
jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
|
| 304 |
+
offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
|
| 305 |
+
return offset
|
| 306 |
+
|
| 307 |
+
@triton.jit
|
| 308 |
+
def get_bounded_indices(indices, max_len=None):
|
| 309 |
+
return indices % max_len if max_len is not None else indices
|
| 310 |
+
|
| 311 |
+
@triton.jit
|
| 312 |
+
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
|
| 313 |
+
if IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 314 |
+
return tl.load(block_ptr)
|
| 315 |
+
elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
|
| 316 |
+
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
|
| 317 |
+
elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 318 |
+
return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
|
| 319 |
+
else:
|
| 320 |
+
return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
|
| 321 |
+
|
| 322 |
+
@triton.jit
|
| 323 |
+
def load_checked_2d(
|
| 324 |
+
ptr,
|
| 325 |
+
offs_m,
|
| 326 |
+
offs_n,
|
| 327 |
+
stride_m,
|
| 328 |
+
stride_n,
|
| 329 |
+
IS_DIVISIBLE_M: tl.constexpr,
|
| 330 |
+
IS_DIVISIBLE_N: tl.constexpr,
|
| 331 |
+
M_LEN: tl.constexpr,
|
| 332 |
+
N_LEN: tl.constexpr,
|
| 333 |
+
):
|
| 334 |
+
# Calculate final pointer if strides are provided
|
| 335 |
+
if stride_m is not None and stride_n is not None:
|
| 336 |
+
ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
|
| 337 |
+
|
| 338 |
+
# Handle all masking cases
|
| 339 |
+
if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
|
| 340 |
+
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
|
| 341 |
+
elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
|
| 342 |
+
return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
|
| 343 |
+
elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
|
| 344 |
+
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
|
| 345 |
+
else: # Both divisible
|
| 346 |
+
return tl.load(ptr)
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
# Common Imports
|
| 350 |
+
@triton.jit
|
| 351 |
+
def forward_block_mn(
|
| 352 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
|
| 353 |
+
q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
|
| 354 |
+
# accumulated values
|
| 355 |
+
acc, l_i, m_i,
|
| 356 |
+
# Offsets
|
| 357 |
+
off_z, off_h, offs_m, offs_n,
|
| 358 |
+
# Offsets needed for TMA loads
|
| 359 |
+
kv_start,
|
| 360 |
+
kv_offset,
|
| 361 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 362 |
+
# Strides for K and V
|
| 363 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 364 |
+
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
|
| 365 |
+
|
| 366 |
+
):
|
| 367 |
+
# Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
|
| 368 |
+
PRESCALE_QK : tl.constexpr = False
|
| 369 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 370 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 371 |
+
WRITE_DQ : tl.constexpr = True
|
| 372 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 373 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 374 |
+
FLOAT32_PRECISION : tl.constexpr = 'ieee'
|
| 375 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 376 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 377 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 378 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 379 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 380 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 381 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 382 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 383 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 384 |
+
USE_TMA : tl.constexpr = False
|
| 385 |
+
BLOCK_M : tl.constexpr = 128
|
| 386 |
+
BLOCK_N : tl.constexpr = 64
|
| 387 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 388 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 389 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
# -- load k --
|
| 393 |
+
# NB reversed order to since K is transposed
|
| 394 |
+
kv_base_offset = kv_start + kv_offset
|
| 395 |
+
|
| 396 |
+
# Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
|
| 397 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 398 |
+
offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
|
| 399 |
+
k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
|
| 400 |
+
|
| 401 |
+
k = tl.trans(k)
|
| 402 |
+
# -- compute qk ---
|
| 403 |
+
qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
|
| 404 |
+
if not PRESCALE_QK:
|
| 405 |
+
qk *= SM_SCALE
|
| 406 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
|
| 407 |
+
# If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
|
| 408 |
+
# which is larger than the actual number of elements. To avoid access memory out of bound,
|
| 409 |
+
# we need to mask out the elements that are out of Q_LEN & KV_LEN.
|
| 410 |
+
m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
|
| 411 |
+
n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
|
| 412 |
+
|
| 413 |
+
tmp0 = (qk)
|
| 414 |
+
post_mod_scores = tmp0
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
if CHECK_BLOCK_BOUNDARY:
|
| 418 |
+
# Mask out the elements that are out of the KV_LEN for non divisible seqlen.
|
| 419 |
+
post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
|
| 420 |
+
|
| 421 |
+
if not IS_FULL_BLOCKS:
|
| 422 |
+
tmp1 = (m)
|
| 423 |
+
tmp2 = tl.full([1], 0, tl.int32)
|
| 424 |
+
tmp3 = tmp1 < tmp2
|
| 425 |
+
tmp4 = (n)
|
| 426 |
+
tmp5 = tmp4 <= tmp1
|
| 427 |
+
tmp6 = tmp3 & tmp5
|
| 428 |
+
tmp7 = tmp1 >= tmp2
|
| 429 |
+
tmp8 = tmp4 < tmp2
|
| 430 |
+
tmp9 = tmp7 & tmp8
|
| 431 |
+
tmp10 = tmp8 == 0
|
| 432 |
+
tmp11 = tmp7 & tmp10
|
| 433 |
+
tmp12 = tmp1 - tmp2
|
| 434 |
+
tmp13 = tl.full([1], 16, tl.int32)
|
| 435 |
+
tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
|
| 436 |
+
tmp15 = tmp4 - tmp2
|
| 437 |
+
tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
|
| 438 |
+
tmp17 = tmp14 == tmp16
|
| 439 |
+
tmp18 = tmp11 & tmp17
|
| 440 |
+
tmp19 = tmp9 | tmp18
|
| 441 |
+
tmp20 = tmp6 | tmp19
|
| 442 |
+
mask_mod_output = tmp20
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
if CHECK_BLOCK_BOUNDARY:
|
| 446 |
+
mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
|
| 447 |
+
# apply mask for partially unmasked blocks
|
| 448 |
+
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
|
| 449 |
+
|
| 450 |
+
if not PRESCALE_QK:
|
| 451 |
+
post_mod_scores *= RCP_LN2
|
| 452 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 453 |
+
|
| 454 |
+
# -- compute scaling constant ---
|
| 455 |
+
m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
|
| 456 |
+
if not ROWS_GUARANTEED_SAFE:
|
| 457 |
+
masked_out_rows = (m_ij == float("-inf"))
|
| 458 |
+
m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
|
| 459 |
+
else:
|
| 460 |
+
m_ij_masked = m_ij
|
| 461 |
+
|
| 462 |
+
alpha = tl.math.exp2(m_i - m_ij_masked)
|
| 463 |
+
p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
|
| 464 |
+
|
| 465 |
+
# NB: l_i update is pulled up here since it's a bit faster
|
| 466 |
+
# NB: For headdim=256, it's faster to move it back down to after m_i =
|
| 467 |
+
# m_ij
|
| 468 |
+
l_i = l_i * alpha + tl.sum(p, 1)
|
| 469 |
+
# # -- scale and update acc --
|
| 470 |
+
acc = acc * alpha[:, None]
|
| 471 |
+
# Calculate offsets for V loading - reuse kv_base_offset from K loading
|
| 472 |
+
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
| 473 |
+
v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
|
| 474 |
+
acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
|
| 475 |
+
|
| 476 |
+
# -- update m_i
|
| 477 |
+
m_i = m_ij
|
| 478 |
+
|
| 479 |
+
return acc, l_i, m_i
|
| 480 |
+
|
| 481 |
+
@triton.jit
|
| 482 |
+
def forward_inner(
|
| 483 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
|
| 484 |
+
q, K, V,
|
| 485 |
+
desc_k, desc_v, Q_LEN, KV_LEN,
|
| 486 |
+
# accumulated values
|
| 487 |
+
acc, l_i, m_i,
|
| 488 |
+
# Offsets used as inputs to score_mod & mask_mod
|
| 489 |
+
# of size [BLOCK_M, BLOCK_N] or scalar.
|
| 490 |
+
off_z, off_h, offs_m, offs_n,
|
| 491 |
+
# Offsets needed for TMA loads
|
| 492 |
+
kv_start,
|
| 493 |
+
# blocksparse data
|
| 494 |
+
kv_indices, kv_num_blocks,
|
| 495 |
+
# start kv and end kv block
|
| 496 |
+
block_n_start, block_n_end,
|
| 497 |
+
MATMUL_PRECISION,
|
| 498 |
+
# Strides for K and V
|
| 499 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 500 |
+
IS_FULL_BLOCKS,
|
| 501 |
+
):
|
| 502 |
+
# Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
|
| 503 |
+
PRESCALE_QK : tl.constexpr = False
|
| 504 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 505 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 506 |
+
WRITE_DQ : tl.constexpr = True
|
| 507 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 508 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 509 |
+
FLOAT32_PRECISION : tl.constexpr = 'ieee'
|
| 510 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 511 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 512 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 513 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 514 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 515 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 516 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 517 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 518 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 519 |
+
USE_TMA : tl.constexpr = False
|
| 520 |
+
BLOCK_M : tl.constexpr = 128
|
| 521 |
+
BLOCK_N : tl.constexpr = 64
|
| 522 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 523 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 524 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
|
| 528 |
+
RCP_LN2: tl.constexpr = 1.44269504
|
| 529 |
+
|
| 530 |
+
if PRESCALE_QK:
|
| 531 |
+
q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
|
| 532 |
+
|
| 533 |
+
kv_offset = 0
|
| 534 |
+
|
| 535 |
+
# loop over k, v and update accumulator until block_n_end
|
| 536 |
+
for start_n in range(block_n_start, block_n_end):
|
| 537 |
+
# Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
|
| 538 |
+
if IS_DIVISIBLE:
|
| 539 |
+
acc, l_i, m_i = forward_block_mn(
|
| 540 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
|
| 541 |
+
q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
|
| 542 |
+
# accumulated values
|
| 543 |
+
acc, l_i, m_i,
|
| 544 |
+
# Offsets
|
| 545 |
+
off_z, off_h, offs_m, offs_n,
|
| 546 |
+
# Offsets needed for TMA loads
|
| 547 |
+
kv_start,
|
| 548 |
+
kv_offset,
|
| 549 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 550 |
+
# Strides for K and V
|
| 551 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 552 |
+
IS_FULL_BLOCKS,
|
| 553 |
+
)
|
| 554 |
+
else:
|
| 555 |
+
# Benchmark shows even we applied mod & mask to each block for non divisible seqlen,
|
| 556 |
+
# it's on par or slightly faster than only applying to the last block in fwd.
|
| 557 |
+
# However, we choose different strategy for bwd, where we only apply mod & mask
|
| 558 |
+
# to the last block because it's faster a lot.
|
| 559 |
+
acc, l_i, m_i = forward_block_mn(
|
| 560 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
|
| 561 |
+
q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
|
| 562 |
+
# accumulated values
|
| 563 |
+
acc, l_i, m_i,
|
| 564 |
+
# Offsets
|
| 565 |
+
off_z, off_h, offs_m, offs_n,
|
| 566 |
+
# Offsets needed for TMA loads
|
| 567 |
+
kv_start,
|
| 568 |
+
kv_offset,
|
| 569 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 570 |
+
# Strides for K and V
|
| 571 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 572 |
+
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
|
| 573 |
+
)
|
| 574 |
+
|
| 575 |
+
|
| 576 |
+
|
| 577 |
+
offset = get_offset_for_next_block(
|
| 578 |
+
start_n, kv_indices, kv_num_blocks,
|
| 579 |
+
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
|
| 580 |
+
)
|
| 581 |
+
|
| 582 |
+
offs_n = offs_n + offset
|
| 583 |
+
kv_offset += offset
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
return acc, l_i, m_i
|
| 587 |
+
''', device_str='cuda')
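
# NOTE (reviewer sketch, not part of the generated module): the template above
# implements the FlashAttention-style streaming ("online") softmax. Per query
# row it keeps a running max m_i, a running denominator l_i, and a rescaled
# accumulator acc; it works in base 2 (exp2/log2), so lse = m_i + log2(l_i).
# The pure-Python reference below shows the same recurrence in natural log,
# which is equivalent up to the change of base; names here are illustrative.
import math

def _online_logsumexp_reference(score_blocks):
    # Streaming logsumexp over per-block score lists for one row; assumes at
    # least one finite score overall.
    m_i = float("-inf")  # running row max
    l_i = 0.0            # running sum of exp(score - m_i)
    for block in score_blocks:
        m_ij = max(m_i, max(block))                      # new running max
        alpha = math.exp(m_i - m_ij) if l_i else 0.0     # rescale old sum
        l_i = l_i * alpha + sum(math.exp(s - m_ij) for s in block)
        m_i = m_ij
    return m_i + math.log(l_i)

assert abs(_online_logsumexp_reference([[0.0, 1.0], [2.0]])
           - math.log(math.exp(0.0) + math.exp(1.0) + math.exp(2.0))) < 1e-12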


# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/qr/cqrng7hmawuvea5b46xnw26e3vaokywqdqnuhn4vt7tmtdoleeab.py
# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul]
# Source node to ATen node mapping:
#    lse_scaled => mul_9
# Graph fragment:
# %buf3 : Tensor = PlaceHolder[target=buf3]
# %mul_9 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {})
# return %mul_9
triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 4096},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x2 = xindex
    x0 = (xindex % ks0)
    x1 = triton_helpers.div_floor_integer(xindex, ks0)
    tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last')
    tmp1 = 0.6931471805599453
    tmp2 = tmp0 * tmp1
    tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask)
''', device_str='cuda')
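
# NOTE (reviewer sketch, not part of the generated module): 0.6931471805599453
# is ln(2). triton_tem_fused_0 stores lse in base 2 (m_i + log2(l_i)), and this
# pointwise kernel converts it to the natural log via log(x) = log2(x) * ln(2).
# A tiny illustrative check of that identity:
import math
assert math.isclose(math.log(10.0), math.log2(10.0) * 0.6931471805599453)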


async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1 = args
        args.clear()
        s50 = arg0_1
        s0 = arg2_1
        s43 = arg4_1
        s37 = arg7_1
        s71 = arg8_1
        assert_size_stride(arg1_1, (1, 32, s37, 128), (4096*s37, 128, 4096, 1))
        assert_size_stride(arg3_1, (1, 8, s0, 128), (1024*s0, 128, 1024, 1))
        assert_size_stride(arg5_1, (1, 8, s43, 128), (1024*s43, 128, 1024, 1))
        assert_size_stride(arg6_1, (1, 1, 1, 1), (1, 1, 1, 1))
        assert_size_stride(arg9_1, (1, 1, 1), (1, 1, 1))
        assert_size_stride(arg10_1, (1, 1, 1), (1, 1, 1))
        assert_size_stride(arg11_1, (1, 1, 1, 1), (1, 1, 1, 1))
        assert_size_stride(arg12_1, (1, 1, 1), (1, 1, 1))
        assert_size_stride(arg13_1, (1, 1, 1, 1), (1, 1, 1, 1))
        assert_size_stride(arg14_1, (1, 1, 1), (1, 1, 1))
        assert_size_stride(arg15_1, (1, 1, 1, 1), (1, 1, 1, 1))
        with torch.cuda._DeviceGuard(1):
            torch.cuda.set_device(1)
            buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32)
            buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32)
            buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16)
            # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
            stream1 = get_raw_stream(1)
            triton_tem_fused_0.run(arg1_1, arg3_1, arg5_1, buf0, buf1, arg9_1, arg6_1, arg10_1, arg11_1, buf2, s37, s0, s43, (127 + s37) // 128, 1, 32, stream=stream1)
            del arg10_1
            del arg11_1
            del arg1_1
            del arg3_1
            del arg5_1
            del arg6_1
            del arg9_1
            del buf1
            buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32)
            # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul]
            triton_poi_fused_mul_1_xnumel = 32*s37
            stream1 = get_raw_stream(1)
            triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream1)
            del buf0
        return (buf2, buf5, )

runner = Runner(partitions=[])
call = runner.call
recursively_apply_fns = runner.recursively_apply_fns


def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    arg0_1 = 128
    arg1_1 = rand_strided((1, 32, 128, 128), (524288, 128, 4096, 1), device='cuda:1', dtype=torch.bfloat16)
    arg2_1 = 128
    arg3_1 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:1', dtype=torch.bfloat16)
    arg4_1 = 128
    arg5_1 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:1', dtype=torch.bfloat16)
    arg6_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:1', dtype=torch.int32)
    arg7_1 = 128
    arg8_1 = 128
    arg9_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:1', dtype=torch.int32)
    arg10_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:1', dtype=torch.int32)
    arg11_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:1', dtype=torch.int32)
    arg12_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:1', dtype=torch.int32)
    arg13_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:1', dtype=torch.int32)
    arg14_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:1', dtype=torch.int32)
    arg15_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:1', dtype=torch.int32)
    fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1])
    return print_performance(fn, times=times, repeat=repeat)


if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
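
# NOTE (reviewer sketch, not part of the generated module): call() above
# launches triton_tem_fused_0 on a grid of ((127 + s37) // 128, 1, 32), i.e.
# ceil(s37 / BLOCK_M) query blocks with BLOCK_M = 128, one batch, and 32 query
# heads. A tiny illustrative check of that ceiling division:
def _ceil_div(a, b):
    return (a + b - 1) // b

assert _ceil_div(300, 128) == (127 + 300) // 128 == 3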
progress/SpecForge/cache/compiled_kernels/2z/c2zdv5arszdl6ednyphqfnib6jwgzomr6zt6536b7gq75kp67uvh.py
ADDED
@@ -0,0 +1,1046 @@
# AOT ID: ['2_backward']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream

aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
assert_alignment = torch._C._dynamo.guards.assert_alignment
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p


# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sh/csh76hcjkj7bc6jvydzdmaapo6vnfxlvc3xvqexzngu63td4qnjk.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul]
# Source node to ATen node mapping:
# Graph fragment:
# %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:4" = PlaceHolder[target=getitem]
# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:4" = PlaceHolder[target=tangents_1]
# %buf0 : Tensor "bf16[1, 32, s37][32*s37, s37, 1]cuda:4" = PlaceHolder[target=buf0]
# %tangents_2 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:4" = PlaceHolder[target=tangents_2]
# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {})
# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_15, %primals_18, %primals_20, %primals_23, %primals_25, %primals_28, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
# return %buf0,%buf1
triton_red_fused_mul_0 = async_compile.triton('triton_red_fused_mul_0', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.reduction(
    size_hints={'x': 32768, 'r0_': 128},
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    r0_numel = 128
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = (xindex % ks0)
    x1 = triton_helpers.div_floor_integer(xindex, ks0)
    _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    x3 = xindex
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_2 = r0_index
        tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp2 = tmp0 * tmp1
        tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
        tmp5 = _tmp4 + tmp3
        _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4)
    tmp4 = tl.sum(_tmp4, 1)[:, None]
    tmp7 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last')
    tmp6 = tmp4.to(tl.float32)
    tmp8 = 0.6931471805599453
    tmp9 = tmp7 * tmp8
    tmp10 = 1.4426950408889634
    tmp11 = tmp9 * tmp10
    tmp12 = tmp6 - tmp11
    tl.store(out_ptr1 + (x3), tmp12, xmask)
''', device_str='cuda')
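
# NOTE (reviewer sketch, not part of the generated module): this reduction
# appears to fuse the backward pass's DELTA precompute, sum(out * grad_out)
# over the 128-wide head dim, with the lse tangent: the incoming natural-log
# lse gradient is scaled by ln(2) * log2(e), which is exactly 1, on its way
# back to the base-2 quantities the backward template works in. A tiny
# illustrative check of that cancellation:
import math
assert math.isclose(0.6931471805599453 * 1.4426950408889634, 1.0)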
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ek/cekdygnnt4twwaq4fpapciid2veg5uc5gzwte4mymxe7ertv26cs.py
|
| 103 |
+
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul]
|
| 104 |
+
# Source node to ATen node mapping:
|
| 105 |
+
# Graph fragment:
|
| 106 |
+
# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:4" = PlaceHolder[target=primals_2]
|
| 107 |
+
# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:4" = PlaceHolder[target=primals_4]
|
| 108 |
+
# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:4" = PlaceHolder[target=primals_6]
|
| 109 |
+
# %getitem_1 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:4" = PlaceHolder[target=getitem_1]
|
| 110 |
+
# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:4" = PlaceHolder[target=buf1]
|
| 111 |
+
# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:4" = PlaceHolder[target=tangents_1]
|
| 112 |
+
# %getitem_3 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:4" = PlaceHolder[target=getitem_3]
|
| 113 |
+
# %getitem_5 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:4" = PlaceHolder[target=getitem_5]
|
| 114 |
+
# %primals_13 : Tensor "i32[1, 1, s99][s99, s99, 1]cuda:4" = PlaceHolder[target=primals_13]
|
| 115 |
+
# %primals_9 : Tensor "i32[1, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:4" = PlaceHolder[target=primals_9]
|
| 116 |
+
# %primals_20 : Tensor "i32[1, 1, s56][s56, s56, 1]cuda:4" = PlaceHolder[target=primals_20]
|
| 117 |
+
# %primals_23 : Tensor "i32[1, 1, s84, s53][s53*s84, s53*s84, s53, 1]cuda:4" = PlaceHolder[target=primals_23]
|
| 118 |
+
# %primals_15 : Tensor "i32[1, 1, s94][s94, s94, 1]cuda:4" = PlaceHolder[target=primals_15]
|
| 119 |
+
# %primals_18 : Tensor "i32[1, 1, s28, s4][s28*s4, s28*s4, s4, 1]cuda:4" = PlaceHolder[target=primals_18]
|
| 120 |
+
# %primals_25 : Tensor "i32[1, 1, s100][s100, s100, 1]cuda:4" = PlaceHolder[target=primals_25]
|
| 121 |
+
# %primals_28 : Tensor "i32[1, 1, s5, s10][s10*s5, s10*s5, s10, 1]cuda:4" = PlaceHolder[target=primals_28]
|
| 122 |
+
# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {})
|
| 123 |
+
# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_15, %primals_18, %primals_20, %primals_23, %primals_25, %primals_28, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
|
| 124 |
+
# return %getitem_4
|
| 125 |
+
triton_tem_fused_mul_1 = async_compile.triton('triton_tem_fused_mul_1', '''
|
| 126 |
+
import triton
|
| 127 |
+
import triton.language as tl
|
| 128 |
+
|
| 129 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 130 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 131 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 132 |
+
|
| 133 |
+
@triton_heuristics.template(
|
| 134 |
+
|
| 135 |
+
num_stages=3,
|
| 136 |
+
num_warps=8,
|
| 137 |
+
triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]},
|
| 138 |
+
inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
|
| 139 |
+
|
| 140 |
+
)
|
| 141 |
+
@triton.jit
|
| 142 |
+
def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7):
|
| 143 |
+
PRESCALE_QK : tl.constexpr = False
|
| 144 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 145 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 146 |
+
WRITE_DQ : tl.constexpr = True
|
| 147 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 148 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 149 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 150 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 151 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 152 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 153 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 154 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 155 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 156 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 157 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 158 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 159 |
+
BLOCK_M1 : tl.constexpr = 64
|
| 160 |
+
BLOCK_N1 : tl.constexpr = 128
|
| 161 |
+
BLOCK_M2 : tl.constexpr = 128
|
| 162 |
+
BLOCK_N2 : tl.constexpr = 64
|
| 163 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 164 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 165 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 166 |
+
Q = arg_Q
|
| 167 |
+
K = arg_K
|
| 168 |
+
V = arg_V
|
| 169 |
+
LSE = arg_LSE
|
| 170 |
+
DELTA = arg_DELTA
|
| 171 |
+
DO = arg_DO
|
| 172 |
+
DQ = arg_DQ
|
| 173 |
+
DV = arg_DV
|
| 174 |
+
KV_NUM_BLKS = arg_KV_NUM_BLKS
|
| 175 |
+
KV_IDX = arg_KV_IDX
|
| 176 |
+
Q_NUM_BLKS = arg_Q_NUM_BLKS
|
| 177 |
+
Q_IDX = arg_Q_IDX
|
| 178 |
+
FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
|
| 179 |
+
FULL_KV_IDX = arg_FULL_KV_IDX
|
| 180 |
+
FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
|
| 181 |
+
FULL_Q_IDX = arg_FULL_Q_IDX
|
| 182 |
+
|
| 183 |
+
# Sub notation for this kernel:
|
| 184 |
+
#
|
| 185 |
+
# Q: Query, K: Key, V: Value
|
| 186 |
+
# LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
|
| 187 |
+
# DELTA: Precomputed sum(OUT*DO, axis=-1)
|
| 188 |
+
# DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
|
| 189 |
+
# DK: Derivative of Key, is the written to via the store_output call due to some limitations with
|
| 190 |
+
# inductor codegen
|
| 191 |
+
# M: Number of queries, N: Number of keys/values
|
| 192 |
+
# QK_HEAD_DIM: The dimension of the query and key embeddings
|
| 193 |
+
# V_HEAD_DIM: The dimension of the value embeddings
|
| 194 |
+
# z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
|
| 195 |
+
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
|
| 196 |
+
# (Modifiable) Performance tuning options
|
| 197 |
+
# BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
|
| 198 |
+
# BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
|
| 199 |
+
# BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
|
| 200 |
+
# BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
|
| 201 |
+
#
|
| 202 |
+
# The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
|
| 203 |
+
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
|
| 204 |
+
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
|
| 205 |
+
# Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
|
| 206 |
+
# Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
|
| 207 |
+
# FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
|
| 208 |
+
# FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
|
| 209 |
+
# FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
|
| 210 |
+
# FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
|
| 211 |
+
|
| 212 |
+
# The below are kernel options that can be applied for certain score_mods,
|
| 213 |
+
# or involve a numerics vs. perf tradeoff
|
| 214 |
+
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
|
| 215 |
+
# about 20% more numerical error, but slightly faster.
|
| 216 |
+
|
| 217 |
+
    # Define strides of inputs
    stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1
    stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1
    stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1
    stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1

    stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1
    stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1

    ZQ = 1
    HQ = 32
    HKV = 8
    Q_LEN = ks0
    ZKV = 1
    KV_LEN = ks1

    MATMUL_PRECISION = Q.dtype.element_ty

    pid = tl.program_id(0).to(INDEX_DTYPE)
    NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
    NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)

    off_zq = tl.program_id(1).to(INDEX_DTYPE)  # q batch idx
    off_hkv = tl.program_id(2).to(INDEX_DTYPE)  # kv head idx
    off_zkv = off_zq % ZKV  # kv batch idx

    SPARSE_Z = 1
    SPARSE_HQ = 1

    sparse_idx_z = off_zq % SPARSE_Z

    k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
    v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
    # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
    # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
    dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)

    # offset K, V, DV pointers for batch/kv-head
    K += k_adj
    V += v_adj
    DV += dv_adj

    RCP_LN2 = 1.44269504
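    # Editor's note (illustrative): RCP_LN2 = 1/ln(2) = log2(e) ~= 1.44269504,
    # so exp(x) == exp2(x * RCP_LN2). Scores are multiplied by RCP_LN2 once,
    # letting the softmax terms below be evaluated with tl.math.exp2, which is
    # cheaper on GPU than exp.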
    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)

    if pid >= NUM_KV_BLOCKS:
        off_pid = pid - NUM_KV_BLOCKS
        # THIS BLOCK DOES DQ
        SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
        SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
        off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
        start_m2_block = off_pid % NUM_Q_BLOCKS
        off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
        stride_kv_num_blks_h = ks2
        stride_kv_idx_h = ks3*ks4
        stride_kv_idx_m = ks4

        sparse_idx_hq2 = off_hq2 % SPARSE_HQ
        sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2

        sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
        sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m  # noqa: B950

        # Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query heads.
        q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
        do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
        dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
        off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)

        Q2 = Q + q_adj2
        DO2 = DO + do_adj2
        # TODO: This does not work if DQ is not the same layout as Q (for example,
        # if Q is broadcasted)
        DQ2 = DQ + dq_adj2
        LSE2 = LSE + off_chz2
        DELTA2 = DELTA + off_chz2

        # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
        dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)

        start_m2 = start_m2_block * BLOCK_M2
        offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)

        # load Q and do: they stay in SRAM throughout the inner loop.
        q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
        do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)

        if PRESCALE_QK:
            q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)

        if IS_DIVISIBLE:
            Di = tl.load(DELTA2 + offs_m2)
            lse = tl.load(LSE2 + offs_m2)
        else:
            Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
            lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
        lse = tl.where(lse == -float("inf"), 0.0, lse)
        lse = lse[:, None]

        # ~~~~~~~~~~~ blocks that may require masking ~~~~~~~~~~~~~~~~~~~~
        # KV_IDX and KV_NUM_BLKS are always contiguous.
        kv_indices = KV_IDX + sparse_kv_idx_offset
        kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE  # first kv block we're loading
        sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)

        offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
        dq = bwd_dq_inner(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7,
            K, V,
            dq, q, do, Di, lse,
            off_zq, off_hq2, offs_m2, offs_n2,
            stride_kn, stride_kd, stride_vn, stride_vd,
            kv_indices, sparse_kv_num_blocks,
            MATMUL_PRECISION,
            IS_FULL_BLOCKS=False,
        )

        if HAS_FULL_BLOCKS:
            # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~
            # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
            kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
            kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE  # first kv block we're loading
            sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)

            offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
            dq = bwd_dq_inner(
                arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7,
                K, V,
                dq, q, do, Di, lse,
                off_zq, off_hq2, offs_m2, offs_n2,
                stride_kn, stride_kd, stride_vn, stride_vd,
                kv_indices, sparse_kv_num_blocks,
                MATMUL_PRECISION,
                IS_FULL_BLOCKS=True,
            )

        # Write back dQ.
        dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
        dq *= SM_SCALE
        if IS_DIVISIBLE and SAFE_HEAD_DIM:
            tl.store(dq_ptrs, dq)
        else:
            tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
    else:
        # THIS BLOCK DOES DK & DV
        SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
        SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)

        pid_mask = pid // SPARSE_KV_MULTIPLE

        stride_q_num_blks_h = ks5
        stride_q_idx_h = ks6*ks7
        stride_q_idx_n = ks6

        dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
        dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)

        start_n1 = pid * BLOCK_N1
        offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)

        # load K and V: they stay in SRAM throughout the inner loop.
        k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
        v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)

        if PRESCALE_QK:
            k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)

        for off_g in range(0, GQA_SHARED_HEADS):
            off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g

            # Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query heads.
            q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
            do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
            dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
            off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)

            Q1 = Q + q_adj1
            DO1 = DO + do_adj1
            # TODO: This does not work if DQ is not the same layout as Q (for example,
            # if Q is broadcasted)
            LSE1 = LSE + off_chz1
            DELTA1 = DELTA + off_chz1

            sparse_idx_hq1 = off_hq1 % SPARSE_HQ
            sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1

            sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
            sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n  # noqa: B950

            # ~~~~~~~~~~~~~~~ blocks that may require masking ~~~~~~~~~~~~~~~
            # Q_IDX and Q_NUM_BLKS are always contiguous.
            q_indices = Q_IDX + sparse_q_idx_offset
            q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE  # first q block we're loading
            sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)

            offs_m1 = q_start + tl.arange(0, BLOCK_M1)
            dk, dv = bwd_dkdv_inner(
                arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7,
                Q1, DO1, DELTA1, LSE1,
                dk, dv, k, v,
                off_zq, off_hq1, offs_n1, offs_m1,
                stride_qm, stride_qd, stride_dom, stride_dod,
                q_indices, sparse_q_num_blocks,
                MATMUL_PRECISION,
                IS_FULL_BLOCKS=False,
            )

            if HAS_FULL_BLOCKS:
                # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~
                # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
                q_indices = FULL_Q_IDX + sparse_q_idx_offset
                q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE  # first q block we're loading
                sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)

                offs_m1 = q_start + tl.arange(0, BLOCK_M1)
                dk, dv = bwd_dkdv_inner(
                    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7,
                    Q1, DO1, DELTA1, LSE1,
                    dk, dv, k, v,
                    off_zq, off_hq1, offs_n1, offs_m1,
                    stride_qm, stride_qd, stride_dom, stride_dod,
                    q_indices, sparse_q_num_blocks,
                    MATMUL_PRECISION,
                    IS_FULL_BLOCKS=True,
                )

        # Write back dV and dK.
        dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd

        index_n = offs_n1[:, None]
        index_k = offs_k[None, :]
        index_v = offs_v[None, :]

        if IS_DIVISIBLE and SAFE_HEAD_DIM:
            tl.store(dv_ptrs, dv)
        else:
            tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))

        dk *= SM_SCALE

        if SAFE_HEAD_DIM:
            mask = index_n < KV_LEN
        else:
            mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)

        # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, QK_HEAD_DIM]
        # then reduce to dk of shape [Bkv, Hkv, KV_LEN, QK_HEAD_DIM]
        tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
        xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1
        tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask)

@triton.jit
def bwd_dq_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7,
    K, V,  # pointers
    dq, q, do, Di, lse,
    off_z, off_hq, offs_m2, offs_n2,
    stride_kn, stride_kd, stride_vn, stride_vd,
    kv_indices, sparse_kv_num_blocks,
    MATMUL_PRECISION,
    IS_FULL_BLOCKS,
):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
    RCP_LN2: tl.constexpr = 1.44269504
    Q_LEN = ks0
    KV_LEN = ks1

    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)

    kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
    vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
    # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)

    hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))

    for start_n in range(0, hi):
        dq = bwd_dq_block_mn(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7,
            dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
            off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
            stride_kn, stride_kd, stride_vn, stride_vd,
            kv_indices, sparse_kv_num_blocks,
            MATMUL_PRECISION, RCP_LN2,
            IS_FULL_BLOCKS,
        )

        # Increment pointers.
        offset = get_offset_for_next_block(
            start_n, kv_indices, sparse_kv_num_blocks,
            SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
        )

        kT_ptrs += offset * stride_kn
        vT_ptrs += offset * stride_vn

        offs_n2 += offset

    return dq


@triton.jit
def bwd_dq_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7,
    dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
    off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
    stride_kn, stride_kd, stride_vn, stride_vd,
    kv_indices, sparse_kv_num_blocks,
    MATMUL_PRECISION, RCP_LN2,
    IS_FULL_BLOCKS,
):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    # NB reversed order since K is transposed
    kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
    qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
    if not PRESCALE_QK:
        qk *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    pre_mod_scores = qk
    n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
    # The boundary check is done for the outer loop, but since we're iterating across the N dim
    # here, the M indices can read out of bounds for the PIDs spanning the Q_LEN boundary
    m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)

    tmp0 = (qk)
    post_mod_scores = tmp0

    if not IS_DIVISIBLE:
        post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        tmp1 = (m)
        tmp2 = tl.full([1], 0, tl.int32)
        tmp3 = tmp1 < tmp2
        tmp4 = (n)
        tmp5 = tmp4 <= tmp1
        tmp6 = tmp3 & tmp5
        tmp7 = tmp1 >= tmp2
        tmp8 = tmp4 < tmp2
        tmp9 = tmp7 & tmp8
        tmp10 = tmp8 == 0
        tmp11 = tmp7 & tmp10
        tmp12 = tmp1 - tmp2
        tmp13 = tl.full([1], 16, tl.int32)
        tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
        tmp15 = tmp4 - tmp2
        tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
        tmp17 = tmp14 == tmp16
        tmp18 = tmp11 & tmp17
        tmp19 = tmp9 | tmp18
        tmp20 = tmp6 | tmp19
        mask_mod_output = tmp20

        # apply mask for partially masked block
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if not PRESCALE_QK:
        post_mod_scores *= RCP_LN2
    p = tl.math.exp2(post_mod_scores - lse)
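    # Editor's note (illustrative): lse was saved by the forward pass already
    # scaled into base-2 units, so
    #     p = exp2(post_mod_scores - lse)
    # equals exp(score_ij) / sum_j exp(score_ij): the forward softmax
    # probabilities are recovered exactly, without materializing them.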
    # Compute dP and dS.
    # NB reversed order since V is transposed
    vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)

    dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
    ds = p * (dp - Di[:, None])
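    # Editor's note (illustrative): this is the standard softmax backward
    # identity. With P = softmax(S) and dP = dO @ V^T,
    #     dS = P * (dP - Di[:, None]),  Di = rowsum(O * dO) = rowsum(P * dP),
    # where Di was precomputed into DELTA, saving a row reduction here.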
    # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
    tmp21 = (ds)
    grad_scores = tmp21

    if not IS_DIVISIBLE:
        grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)

    # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
    if WRITE_DQ:
        scatter_mask = (offs_m2[:, None] < Q_LEN) & (offs_n2[None, :] < KV_LEN)

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    ds = grad_scores

    if not IS_FULL_BLOCKS:
        # (grads) apply mask for partially unmasked block
        ds = tl.where(mask_mod_output, ds, 0.0)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    ds = ds.to(MATMUL_PRECISION)
    # Compute dQ.
    dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)

    return dq


@triton.jit
def bwd_dkdv_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7,
    Q, DO, DELTA, LSE,  # pointers
    dk, dv, k, v,
    off_z, off_hq, offs_n1, offs_m1,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_indices, sparse_q_num_blocks,
    MATMUL_PRECISION,
    IS_FULL_BLOCKS,
):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
    RCP_LN2: tl.constexpr = 1.44269504
    Q_LEN = ks0
    KV_LEN = ks1

    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)

    qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
    do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)

    # The minimum is needed to handle the case where we run with a super large
    # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
    hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))

    for start_m in range(0, hi):
        dk, dv = bwd_dkdv_block_mn(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7,
            dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
            off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
            stride_qm, stride_qd, stride_dom, stride_dod,
            q_indices, sparse_q_num_blocks,
            MATMUL_PRECISION, RCP_LN2,
            IS_FULL_BLOCKS,
        )
        # Increment pointers.
        offset = get_offset_for_next_block(
            start_m, q_indices, sparse_q_num_blocks,
            SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
        )

        qT_ptrs += offset * stride_qm
        do_ptrs += offset * stride_dom
        offs_m1 += offset

    return dk, dv


@triton.jit
def bwd_dkdv_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7,
    dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
    off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_indices, sparse_q_num_blocks,
    MATMUL_PRECISION, RCP_LN2,
    IS_FULL_BLOCKS,
):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    # NB reversed order since Q is transposed
    qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
    # Load LSE before computing qk to reduce pipeline stall.
    if IS_DIVISIBLE:
        lse = tl.load(LSE + offs_m1)
    else:
        lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
    lse = tl.where(lse == -float("inf"), 0.0, lse)
    qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
    if not PRESCALE_QK:
        qkT *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
    # The boundary check is done for the outer loop, but since we're iterating across the M dim
    # here, the n indices can read out of bounds for the PIDs spanning the KV_LEN boundary
    n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)

    pre_mod_scores = qkT
    tmp22 = (qkT)
    post_mod_scores = tmp22

    if not IS_DIVISIBLE:
        post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        tmp23 = (m)
        tmp24 = tl.full([1], 0, tl.int32)
        tmp25 = tmp23 < tmp24
        tmp26 = (n)
        tmp27 = tmp26 <= tmp23
        tmp28 = tmp25 & tmp27
        tmp29 = tmp23 >= tmp24
        tmp30 = tmp26 < tmp24
        tmp31 = tmp29 & tmp30
        tmp32 = tmp30 == 0
        tmp33 = tmp29 & tmp32
        tmp34 = tmp23 - tmp24
        tmp35 = tl.full([1], 16, tl.int32)
        tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35)
        tmp37 = tmp26 - tmp24
        tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35)
        tmp39 = tmp36 == tmp38
        tmp40 = tmp33 & tmp39
        tmp41 = tmp31 | tmp40
        tmp42 = tmp28 | tmp41
        mask_mod_output = tmp42

        # (grads) apply mask for fully masked block
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if not PRESCALE_QK:
        post_mod_scores *= RCP_LN2
    pT = tl.math.exp2(post_mod_scores - lse[None, :])
    do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
    # Compute dV.
    ppT = pT
    dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
    if IS_DIVISIBLE:
        Di = tl.load(DELTA + offs_m1)
    else:
        Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
    # Compute dP and dS.
    dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
    dsT = pT * (dpT - Di[None, :])
    # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
    tmp43 = (dsT)
    grad_scores = tmp43

    if not IS_DIVISIBLE:
        grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)

    # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
    if not WRITE_DQ:
        idx_b = off_z
        idx_h = off_hq
        idx_m = m
        idx_n = n
        scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    dsT = grad_scores
    if not IS_FULL_BLOCKS:
        # (grads) apply mask for partially unmasked block
        dsT = tl.where(mask_mod_output, dsT, 0.0)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)

    return dk, dv

# Utility triton funcs
@triton.jit
def get_offset_for_next_block(
    loop_iter, col_indices, total_blocks,
    SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr
):
    if BLOCKS_ARE_CONTIGUOUS:
        return BLOCK
    cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
    cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
    next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
    needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
    jump_to_block = (next_block - cur_block) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
    offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
    return offset
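
# Editor's note (an illustrative walkthrough, not generated output): with
# SPARSE_BLOCK = 128 and BLOCK = 64 the multiple is 2, so each selected sparse
# block is covered by two tiles and col_indices is only consulted on every
# second iteration. If col_indices = [0, 3, ...], the returned offsets are
# 64 (second tile of block 0), then (3 - 0) * 128 - (2 - 1) * 64 = 320
# (landing at element 384 = block 3), then 64 again, and so on.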

@triton.jit
def get_bounded_indices(indices, max_len=None):
    return indices % max_len if max_len is not None else indices

@triton.jit
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
    if IS_DIVISIBLE and SAFE_HEAD_DIM:
        return tl.load(block_ptr)
    elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
    elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
    else:
        return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")

@triton.jit
def load_checked_2d(
    ptr,
    offs_m,
    offs_n,
    stride_m,
    stride_n,
    IS_DIVISIBLE_M: tl.constexpr,
    IS_DIVISIBLE_N: tl.constexpr,
    M_LEN: tl.constexpr,
    N_LEN: tl.constexpr,
):
    # Calculate final pointer if strides are provided
    if stride_m is not None and stride_n is not None:
        ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n

    # Handle all masking cases
    if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
    elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
    elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
    else:  # Both divisible
        return tl.load(ptr)
''', device_str='cuda')
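
# Editor's note: a minimal PyTorch sketch (not generated code; the helper name
# _ref_load_checked_2d is hypothetical) of the masked-load semantics used by
# load_checked_2d in the kernel source above: out-of-bounds rows/columns are
# zero-filled, matching tl.load(..., other=0.0). Defined only for reference;
# never called by this module.
def _ref_load_checked_2d(x, offs_m, offs_n, M_LEN, N_LEN):
    # True where both the row and column index are in bounds.
    mask = (offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN)
    # Clamp so the gather itself never indexes out of bounds.
    rows = offs_m.clamp(max=M_LEN - 1)
    cols = offs_n.clamp(max=N_LEN - 1)
    vals = x[rows[:, None], cols[None, :]]
    return torch.where(mask, vals, torch.zeros((), dtype=x.dtype, device=x.device))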

async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        primals_10, primals_11, primals_7, primals_8, primals_12, primals_14, primals_16, primals_17, primals_19, primals_22, primals_21, primals_24, primals_27, primals_26, primals_2, primals_4, primals_6, primals_9, primals_13, primals_15, primals_18, primals_20, primals_23, primals_25, primals_28, getitem, getitem_1, tangents_1, tangents_2 = args
        args.clear()
        s37 = primals_10
        s0 = primals_11
        s22 = primals_7
        s72 = primals_8
        s99 = primals_12
        s94 = primals_14
        s28 = primals_16
        s4 = primals_17
        s56 = primals_19
        s53 = primals_22
        s84 = primals_21
        s100 = primals_24
        s10 = primals_27
        s5 = primals_26
        assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1))
        assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1))
        assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1))
        assert_size_stride(primals_9, (1, 1, s22, s72), (s22*s72, s22*s72, s72, 1))
        assert_size_stride(primals_13, (1, 1, s99), (s99, s99, 1))
        assert_size_stride(primals_15, (1, 1, s94), (s94, s94, 1))
        assert_size_stride(primals_18, (1, 1, s28, s4), (s28*s4, s28*s4, s4, 1))
        assert_size_stride(primals_20, (1, 1, s56), (s56, s56, 1))
        assert_size_stride(primals_23, (1, 1, s84, s53), (s53*s84, s53*s84, s53, 1))
        assert_size_stride(primals_25, (1, 1, s100), (s100, s100, 1))
        assert_size_stride(primals_28, (1, 1, s5, s10), (s10*s5, s10*s5, s10, 1))
        assert_size_stride(getitem, (1, 32, s37, 128), (4096*s37, 128, 4096, 1))
        assert_size_stride(getitem_1, (1, 32, s37), (32*max(1, s37), max(1, s37), 1))
        assert_size_stride(tangents_1, (1, 32, s37, 128), (4096*max(1, s37), 128*max(1, s37), 128, 1))
        assert_size_stride(tangents_2, (1, 32, s37), (32*max(1, s37), max(1, s37), 1))
        with torch.cuda._DeviceGuard(4):
            torch.cuda.set_device(4)
            buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32)
            # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul]
            triton_red_fused_mul_0_xnumel = 32*s37
            stream4 = get_raw_stream(4)
            triton_red_fused_mul_0.run(getitem, tangents_1, tangents_2, buf1, s37, triton_red_fused_mul_0_xnumel, 128, stream=stream4)
            del getitem
            del tangents_2
            buf3 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16)
            buf4 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16)
            buf5 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16)
            # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul]
            stream4 = get_raw_stream(4)
            triton_tem_fused_mul_1.run(primals_2, primals_4, primals_6, getitem_1, buf1, tangents_1, buf3, buf4, primals_13, primals_9, primals_20, primals_23, primals_15, primals_18, primals_25, primals_28, buf5, s37, s0, s99, s22, s72, s56, s53, s84, 4*((127 + s37) // 128) + ((127 + s0) // 128), 1, 8, stream=stream4)
            del buf1
            del getitem_1
            del primals_13
            del primals_15
            del primals_18
            del primals_2
            del primals_20
            del primals_23
            del primals_25
            del primals_28
            del primals_4
            del primals_6
            del primals_9
            del tangents_1
        return (None, buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, )

runner = Runner(partitions=[])
call = runner.call
recursively_apply_fns = runner.recursively_apply_fns


def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    primals_10 = 960
    primals_11 = 960
    primals_7 = 8
    primals_8 = 8
    primals_12 = 8
    primals_14 = 8
    primals_16 = 8
    primals_17 = 8
    primals_19 = 8
    primals_22 = 8
    primals_21 = 8
    primals_24 = 8
    primals_27 = 8
    primals_26 = 8
    primals_2 = rand_strided((1, 32, 960, 128), (3932160, 128, 4096, 1), device='cuda:4', dtype=torch.bfloat16)
    primals_4 = rand_strided((1, 8, 960, 128), (983040, 128, 1024, 1), device='cuda:4', dtype=torch.bfloat16)
    primals_6 = rand_strided((1, 8, 960, 128), (983040, 128, 1024, 1), device='cuda:4', dtype=torch.bfloat16)
    primals_9 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:4', dtype=torch.int32)
    primals_13 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:4', dtype=torch.int32)
    primals_15 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:4', dtype=torch.int32)
    primals_18 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:4', dtype=torch.int32)
    primals_20 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:4', dtype=torch.int32)
    primals_23 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:4', dtype=torch.int32)
    primals_25 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:4', dtype=torch.int32)
    primals_28 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:4', dtype=torch.int32)
    getitem = rand_strided((1, 32, 960, 128), (3932160, 128, 4096, 1), device='cuda:4', dtype=torch.bfloat16)
    getitem_1 = rand_strided((1, 32, 960), (30720, 960, 1), device='cuda:4', dtype=torch.float32)
    tangents_1 = rand_strided((1, 32, 960, 128), (3932160, 122880, 128, 1), device='cuda:4', dtype=torch.bfloat16)
    tangents_2 = rand_strided((1, 32, 960), (30720, 960, 1), device='cuda:4', dtype=torch.float32)
    fn = lambda: call([primals_10, primals_11, primals_7, primals_8, primals_12, primals_14, primals_16, primals_17, primals_19, primals_22, primals_21, primals_24, primals_27, primals_26, primals_2, primals_4, primals_6, primals_9, primals_13, primals_15, primals_18, primals_20, primals_23, primals_25, primals_28, getitem, getitem_1, tangents_1, tangents_2])
    return print_performance(fn, times=times, repeat=repeat)


if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
progress/SpecForge/cache/compiled_kernels/2z/c2zqq6qyjomc7iflknbqr7yjdhjux47hzv4nnsi5qfbeqglaip2h.py
ADDED
@@ -0,0 +1,707 @@
# AOT ID: ['4_forward']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream

aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
assert_alignment = torch._C._dynamo.guards.assert_alignment
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p


# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/cf/ccftanvnrini6kruughcnjtpfiarn7zwa2sdotthpo3wbbjituv3.py
# Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
# Source node to ATen node mapping:
#   flex_attention => flex_attention
# Graph fragment:
#   %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:0" = PlaceHolder[target=primals_2]
#   %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:0" = PlaceHolder[target=primals_4]
#   %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:0" = PlaceHolder[target=primals_6]
#   %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:0" = PlaceHolder[target=getitem_1]
#   %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:0" = PlaceHolder[target=buf1]
#   %primals_10 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:0" = PlaceHolder[target=primals_10]
#   %primals_7 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:0" = PlaceHolder[target=primals_7]
#   %primals_11 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:0" = PlaceHolder[target=primals_11]
#   %primals_12 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:0" = PlaceHolder[target=primals_12]
#   %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
#   return %getitem
triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties

@triton_heuristics.template(
    num_stages=3,
    num_warps=8,
    triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]},
    inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
)
@triton.jit
def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    USE_TMA : tl.constexpr = False
    BLOCK_M : tl.constexpr = 128
    BLOCK_N : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32
    Q = arg_Q
    K = arg_K
    V = arg_V
    LSE = arg_LSE
    MAX = arg_MAX
    KV_NUM_BLKS = arg_KV_NUM_BLKS
    KV_IDX = arg_KV_IDX
    FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
    FULL_KV_IDX = arg_FULL_KV_IDX

    # Sub notation for this kernel:
    #
    # Q: Query, K: Key, V: Value
    # M: Number of queries, N: Number of keys/values, D: Model dimension
    # QK_HEAD_DIM: The dimension of the query and key embeddings
    # V_HEAD_DIM: The dimension of the value embeddings
    # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
    # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
    #
    # The following FULL_* and PARTIAL_* are defined in the block sparse mask grid, rather than the thread block grid.
    # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
    # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
    # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
    # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
    #
    # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
    #
    # (Modifiable) Performance tuning options
    # BLOCK_M: The thread block size across the seqlen dim of Q.
    # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.

    # The below are kernel options that can be applied for certain score_mods,
    # or involve a numerics vs. perf tradeoff
    # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
    # about 20% more numerical error, but slightly faster.
    # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
    # is not masked out? If so, we can skip an extra safety check
    # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
    # contiguous? If so, we don't need to do an indirect jump for every block

    tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
    tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)

    # Define strides of inputs
    stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1
    stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1
    stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1

    ZQ = 1
    HQ = 32
    Q_LEN = ks0
    ZKV = 1
    KV_LEN = ks1

    MATMUL_PRECISION = Q.dtype.element_ty

    q_start = tl.program_id(0).to(INDEX_DTYPE)
    off_zq = tl.program_id(1).to(INDEX_DTYPE)
    off_hq = tl.program_id(2).to(INDEX_DTYPE)

    # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
    # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
    off_zkv = off_zq % ZKV
    off_hkv = off_hq // GQA_SHARED_HEADS
    off_g = off_hq % GQA_SHARED_HEADS
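    # Editor's note (illustrative): with HQ = 32 query heads and
    # GQA_SHARED_HEADS = 4 there are 8 kv heads; query heads 4*i .. 4*i+3 all
    # map to kv head i. E.g. off_hq = 5 gives off_hkv = 1 and off_g = 1.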
|
| 159 |
+
q_offset = off_zq * stride_qz + off_hq * stride_qh
|
| 160 |
+
k_offset = off_zkv * stride_kz + off_hkv * stride_kh
|
| 161 |
+
v_offset = off_zkv * stride_vz + off_hkv * stride_vh
|
| 162 |
+
|
| 163 |
+
Q = Q + q_offset
|
| 164 |
+
K = K + k_offset
|
| 165 |
+
V = V + v_offset
|
| 166 |
+
|
| 167 |
+
# Setting up the TMA descriptors for Q, K, V
|
| 168 |
+
desc_q = None
|
| 169 |
+
desc_k = None
|
| 170 |
+
desc_v = None
|
| 171 |
+
|
| 172 |
+
SPARSE_Z = 1
|
| 173 |
+
SPARSE_HQ = 1
|
| 174 |
+
|
| 175 |
+
sparse_idx_z = off_zq % SPARSE_Z
|
| 176 |
+
sparse_idx_hq = off_hq % SPARSE_HQ
|
| 177 |
+
|
| 178 |
+
SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
|
| 179 |
+
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
|
| 180 |
+
|
| 181 |
+
stride_kv_num_blks_h = 1
|
| 182 |
+
stride_kv_idx_h = 1
|
| 183 |
+
stride_kv_idx_m = 1
|
| 184 |
+
|
| 185 |
+
# initialize pointer to m and l
|
| 186 |
+
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
| 187 |
+
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
| 188 |
+
acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
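    # These three tensors form the online-softmax state carried across KV blocks:
    # m_i is the running row maximum of the (base-2 scaled) scores, l_i the running
    # softmax normalizer, and acc the unnormalized weighted sum of V rows.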

    offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)

    # KV_IDX and KV_NUM_BLKS are always contiguous.
    sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
    sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
    sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m  # noqa: B950
    offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)

    # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # We don't know anything "special" about these blocks, so we need to apply
    # both score_mod and mask_mod to them
    kv_indices = KV_IDX + sparse_kv_idx_offset
    kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE  # first kv block we're loading
    kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
    block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
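    # block_n_end counts BLOCK_N-sized loop iterations: each sparse KV block expands
    # to SPARSE_KV_MULTIPLE thread-block iterations, clamped to the number of
    # BLOCK_N tiles that actually fit in KV_LEN (and at least 1).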


    # K and V pointers will be passed directly to forward_inner

    offs_n = kv_start + tl.arange(0, BLOCK_N)


    acc, l_i, m_i = forward_inner(
        arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
        q, K, V,
        desc_k, desc_v, Q_LEN, KV_LEN,
        acc, l_i, m_i,
        off_zq, off_hq, offs_m[:, None], offs_n[None, :],
        kv_start,
        kv_indices, kv_num_blocks,
        0, block_n_end,
        MATMUL_PRECISION,
        stride_kk, stride_kn, stride_vn, stride_vk,
        IS_FULL_BLOCKS=False,
    )

    # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # We know these blocks are guaranteed to be "full", so we don't need to
    # apply mask_mod to them - only score_mod
    if HAS_FULL_BLOCKS:
        # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
        kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
        kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE  # first kv block we're loading
        kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
        block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
        # K and V pointers will be passed directly to forward_inner
        offs_n = kv_start + tl.arange(0, BLOCK_N)

        acc, l_i, m_i = forward_inner(
            arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
            q, K, V,
            desc_k, desc_v, Q_LEN, KV_LEN,
            acc, l_i, m_i,
            off_zq, off_hq, offs_m[:, None], offs_n[None, :],
            kv_start,
            kv_indices, kv_num_blocks,
            0, block_n_end,
            MATMUL_PRECISION,
            stride_kk, stride_kn, stride_vn, stride_vk,
            IS_FULL_BLOCKS=True,
        )


    # [Note] Handle fully masked out rows:
    # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
    # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
    l_i = tl.where(l_i == 0.0, 1, l_i)
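    # Forcing l_i to 1.0 on fully masked rows keeps the division below well-defined:
    # such rows produce acc / 1.0 == 0.0 rather than 0.0 / 0.0.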

    acc = acc / l_i[:, None]
    idx_zq = tl.program_id(1).to(INDEX_DTYPE)
    idx_hq = tl.program_id(2).to(INDEX_DTYPE)
    idx_m = offs_m[:, None].to(INDEX_DTYPE)
    idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)

    mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)

    tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
    xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0
    tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask)

    if OUTPUT_LOGSUMEXP:
        off_hz = off_zq * HQ + off_hq
        l_ptrs = LSE + off_hz * Q_LEN + offs_m
        lse = m_i + tl.math.log2(l_i)
        if IS_DIVISIBLE:
            tl.store(l_ptrs, lse)
        else:
            tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
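    # The stored lse is in base 2 (the scores were scaled by RCP_LN2 = log2(e));
    # the follow-up pointwise kernel multiplies by ln(2) = 0.6931... to recover the
    # natural-log logsumexp (see triton_poi_fused_mul_1 below).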

    if OUTPUT_MAX:
        off_hz = off_zq * HQ + off_hq
        max_ptrs = MAX + off_hz * Q_LEN + offs_m
        if IS_DIVISIBLE:
            tl.store(max_ptrs, m_i)
        else:
            tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)


# Utility triton funcs
@triton.jit
def get_offset_for_next_block(
    loop_iter, col_indices, total_blocks,
    SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr
):
    if BLOCKS_ARE_CONTIGUOUS:
        return BLOCK
    cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
    cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
    next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
    needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
    jump_to_block = (next_block - cur_block) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
    offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
    return offset
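# Worked example for get_offset_for_next_block with this kernel's configuration
# (SPARSE_BLOCK = 128, BLOCK = 64, SPARSE_BLOCK_MULTIPLE = 2): within a sparse
# block the pointer simply advances by BLOCK = 64; on the last sub-iteration of a
# sparse block it instead jumps (next_block - cur_block) * 128 - 64, which lands
# exactly at the start of the next sparse block listed in col_indices.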

@triton.jit
def get_bounded_indices(indices, max_len=None):
    return indices % max_len if max_len is not None else indices

@triton.jit
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
    if IS_DIVISIBLE and SAFE_HEAD_DIM:
        return tl.load(block_ptr)
    elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
    elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
    else:
        return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")

@triton.jit
def load_checked_2d(
    ptr,
    offs_m,
    offs_n,
    stride_m,
    stride_n,
    IS_DIVISIBLE_M: tl.constexpr,
    IS_DIVISIBLE_N: tl.constexpr,
    M_LEN: tl.constexpr,
    N_LEN: tl.constexpr,
):
    # Calculate final pointer if strides are provided
    if stride_m is not None and stride_n is not None:
        ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n

    # Handle all masking cases
    if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
    elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
    elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
    else:  # Both divisible
        return tl.load(ptr)
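# Both load helpers trade bounds checks for speed: when a dimension is known to
# divide the tile size evenly (IS_DIVISIBLE_* / SAFE_HEAD_DIM), its mask or
# boundary_check is compiled out entirely; otherwise out-of-range lanes are
# zero-filled so downstream dot products stay well-defined.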


# Common Imports
@triton.jit
def forward_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
    q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
    # accumulated values
    acc, l_i, m_i,
    # Offsets
    off_z, off_h, offs_m, offs_n,
    # Offsets needed for TMA loads
    kv_start,
    kv_offset,
    MATMUL_PRECISION, RCP_LN2,
    # Strides for K and V
    stride_kk, stride_kn, stride_vn, stride_vk,
    IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,

):
    # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
    PRESCALE_QK: tl.constexpr = False
    ROWS_GUARANTEED_SAFE: tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr = False
    WRITE_DQ: tl.constexpr = True
    OUTPUT_LOGSUMEXP: tl.constexpr = True
    OUTPUT_MAX: tl.constexpr = False
    FLOAT32_PRECISION: tl.constexpr = 'tf32'
    IS_DIVISIBLE: tl.constexpr = False
    SM_SCALE: tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS: tl.constexpr = 4
    HAS_FULL_BLOCKS: tl.constexpr = True
    QK_HEAD_DIM: tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED: tl.constexpr = 128
    V_HEAD_DIM: tl.constexpr = 128
    V_HEAD_DIM_ROUNDED: tl.constexpr = 128
    SAFE_HEAD_DIM: tl.constexpr = True
    USE_TMA: tl.constexpr = False
    BLOCK_M: tl.constexpr = 128
    BLOCK_N: tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE: tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE: tl.constexpr = 128
    INDEX_DTYPE: tl.constexpr = tl.int32


    # -- load k --
    # NB: reversed order since K is transposed
    kv_base_offset = kv_start + kv_offset

    # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
    k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)

    k = tl.trans(k)
    # -- compute qk ---
    qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION)  # TODO: use cuda matmul when q_len <= 2.
    if not PRESCALE_QK:
        qk *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    # If this is the last block of a non-divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
    # which is larger than the actual number of elements. To avoid out-of-bounds memory access,
    # we need to mask out the elements that are out of Q_LEN & KV_LEN.
    m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
    n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)

    tmp0 = (qk)
    post_mod_scores = tmp0


    if CHECK_BLOCK_BOUNDARY:
        # Mask out the elements that are out of the KV_LEN for non-divisible seqlen.
        post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        tmp1 = (m)
        tmp2 = tl.full([1], 0, tl.int32)
        tmp3 = tmp1 < tmp2
        tmp4 = (n)
        tmp5 = tmp4 <= tmp1
        tmp6 = tmp3 & tmp5
        tmp7 = tmp1 >= tmp2
        tmp8 = tmp4 < tmp2
        tmp9 = tmp7 & tmp8
        tmp10 = tmp8 == 0
        tmp11 = tmp7 & tmp10
        tmp12 = tmp1 - tmp2
        tmp13 = tl.full([1], 16, tl.int32)
        tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
        tmp15 = tmp4 - tmp2
        tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
        tmp17 = tmp14 == tmp16
        tmp18 = tmp11 & tmp17
        tmp19 = tmp9 | tmp18
        tmp20 = tmp6 | tmp19
        mask_mod_output = tmp20
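        # One reading of this inlined mask_mod chain: tmp14 and tmp16 are
        # floor divisions of the query index m and key index n by 16, so for the
        # non-negative m, n that occur in-kernel a position is kept exactly when
        # m // 16 == n // 16 - i.e. this appears to be a block-diagonal mask over
        # 16-token blocks (the other two branches only fire for negative indices).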


        if CHECK_BLOCK_BOUNDARY:
            mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
        # apply mask for partially unmasked blocks
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))

    if not PRESCALE_QK:
        post_mod_scores *= RCP_LN2
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # -- compute scaling constant ---
    m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
    if not ROWS_GUARANTEED_SAFE:
        masked_out_rows = (m_ij == float("-inf"))
        m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
    else:
        m_ij_masked = m_ij

    alpha = tl.math.exp2(m_i - m_ij_masked)
    p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
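    # Online-softmax rescale: when the running max rises from m_i to m_ij,
    # exp2(s - m_ij) = exp2(s - m_i) * exp2(m_i - m_ij), so multiplying the old
    # l_i and acc by alpha below re-expresses all partial sums relative to the
    # new maximum without revisiting earlier KV blocks.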

    # NB: l_i update is pulled up here since it's a bit faster
    # NB: For headdim=256, it's faster to move it back down to after m_i =
    # m_ij
    l_i = l_i * alpha + tl.sum(p, 1)
    # -- scale and update acc --
    acc = acc * alpha[:, None]
    # Calculate offsets for V loading - reuse kv_base_offset from K loading
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
    v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
    acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)

    # -- update m_i
    m_i = m_ij

    return acc, l_i, m_i

@triton.jit
def forward_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
    q, K, V,
    desc_k, desc_v, Q_LEN, KV_LEN,
    # accumulated values
    acc, l_i, m_i,
    # Offsets used as inputs to score_mod & mask_mod
    # of size [BLOCK_M, BLOCK_N] or scalar.
    off_z, off_h, offs_m, offs_n,
    # Offsets needed for TMA loads
    kv_start,
    # blocksparse data
    kv_indices, kv_num_blocks,
    # start kv and end kv block
    block_n_start, block_n_end,
    MATMUL_PRECISION,
    # Strides for K and V
    stride_kk, stride_kn, stride_vn, stride_vk,
    IS_FULL_BLOCKS,
):
    # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
    PRESCALE_QK: tl.constexpr = False
    ROWS_GUARANTEED_SAFE: tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr = False
    WRITE_DQ: tl.constexpr = True
    OUTPUT_LOGSUMEXP: tl.constexpr = True
    OUTPUT_MAX: tl.constexpr = False
    FLOAT32_PRECISION: tl.constexpr = 'tf32'
    IS_DIVISIBLE: tl.constexpr = False
    SM_SCALE: tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS: tl.constexpr = 4
    HAS_FULL_BLOCKS: tl.constexpr = True
    QK_HEAD_DIM: tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED: tl.constexpr = 128
    V_HEAD_DIM: tl.constexpr = 128
    V_HEAD_DIM_ROUNDED: tl.constexpr = 128
    SAFE_HEAD_DIM: tl.constexpr = True
    USE_TMA: tl.constexpr = False
    BLOCK_M: tl.constexpr = 128
    BLOCK_N: tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE: tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE: tl.constexpr = 128
    INDEX_DTYPE: tl.constexpr = tl.int32


    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
    RCP_LN2: tl.constexpr = 1.44269504
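    # RCP_LN2 = 1/ln(2) = log2(e): scaling scores by it turns e^x into
    # exp2(x * RCP_LN2), so the kernel can use the cheaper exp2 throughout and
    # only convert back to natural log once, when the LSE is consumed.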

    if PRESCALE_QK:
        q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)

    kv_offset = 0

    # loop over k, v and update accumulator until block_n_end
    for start_n in range(block_n_start, block_n_end):
        # Here IS_DIVISIBLE acts as the start_n = tl.multiple_of(start_n, BLOCK_N) hint from triton_fused_attention.
        if IS_DIVISIBLE:
            acc, l_i, m_i = forward_block_mn(
                arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
                q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
                # accumulated values
                acc, l_i, m_i,
                # Offsets
                off_z, off_h, offs_m, offs_n,
                # Offsets needed for TMA loads
                kv_start,
                kv_offset,
                MATMUL_PRECISION, RCP_LN2,
                # Strides for K and V
                stride_kk, stride_kn, stride_vn, stride_vk,
                IS_FULL_BLOCKS,
            )
        else:
            # Benchmarks show that even if we apply mod & mask to every block for non-divisible seqlens,
            # it's on par with or slightly faster than applying them only to the last block in fwd.
            # However, we choose a different strategy for bwd, where we only apply mod & mask
            # to the last block, because there it is much faster.
            acc, l_i, m_i = forward_block_mn(
                arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
                q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
                # accumulated values
                acc, l_i, m_i,
                # Offsets
                off_z, off_h, offs_m, offs_n,
                # Offsets needed for TMA loads
                kv_start,
                kv_offset,
                MATMUL_PRECISION, RCP_LN2,
                # Strides for K and V
                stride_kk, stride_kn, stride_vn, stride_vk,
                IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
            )



        offset = get_offset_for_next_block(
            start_n, kv_indices, kv_num_blocks,
            SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
        )

        offs_n = offs_n + offset
        kv_offset += offset


    return acc, l_i, m_i
''', device_str='cuda')


# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/uu/cuu2rr2yygwarlbfvcbucg7erbfsky4wxudbfsdny5wzgxewg4ut.py
# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul]
# Source node to ATen node mapping:
#   lse_scaled => mul_15
# Graph fragment:
#   %buf3 : Tensor = PlaceHolder[target=buf3]
#   %mul_15 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {})
#   return %mul_15
triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 4096},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK: tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x2 = xindex
    x0 = (xindex % ks0)
    x1 = triton_helpers.div_floor_integer(xindex, ks0)
    tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last')
    tmp1 = 0.6931471805599453
    tmp2 = tmp0 * tmp1
    tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask)
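    # 0.6931471805599453 is ln(2): this rescales the base-2 logsumexp produced by
    # the flex_attention kernel above into the natural-log LSE (the `lse_scaled`
    # node in the graph fragment).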
''', device_str='cuda')


async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16 = args
        args.clear()
        s50 = primals_1
        s0 = primals_3
        s43 = primals_5
        s37 = primals_8
        s71 = primals_9
        assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1))
        assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1))
        assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1))
        assert_size_stride(primals_7, (1, 1, 1, 1), (1, 1, 1, 1))
        assert_size_stride(primals_10, (1, 1, 1), (1, 1, 1))
        assert_size_stride(primals_11, (1, 1, 1), (1, 1, 1))
        assert_size_stride(primals_12, (1, 1, 1, 1), (1, 1, 1, 1))
        assert_size_stride(primals_13, (1, 1, 1), (1, 1, 1))
        assert_size_stride(primals_14, (1, 1, 1, 1), (1, 1, 1, 1))
        assert_size_stride(primals_15, (1, 1, 1), (1, 1, 1))
        assert_size_stride(primals_16, (1, 1, 1, 1), (1, 1, 1, 1))
        with torch.cuda._DeviceGuard(0):
            torch.cuda.set_device(0)
            buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32)
            buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32)
            buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16)
            # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
            stream0 = get_raw_stream(0)
            triton_tem_fused_0.run(primals_2, primals_4, primals_6, buf0, buf1, primals_10, primals_7, primals_11, primals_12, buf2, s37, s0, (127 + s37) // 128, 1, 32, stream=stream0)
            del buf1
            buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32)
            # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul]
            triton_poi_fused_mul_1_xnumel = 32*s37
            stream0 = get_raw_stream(0)
            triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream0)
        return (buf2, buf5, primals_2, primals_4, primals_6, primals_7, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, buf2, buf0, s37, s0, )

runner = Runner(partitions=[])
call = runner.call
recursively_apply_fns = runner.recursively_apply_fns


def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    primals_1 = 128
    primals_2 = rand_strided((1, 32, 128, 128), (524288, 128, 4096, 1), device='cuda:0', dtype=torch.bfloat16)
    primals_3 = 128
    primals_4 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
    primals_5 = 128
    primals_6 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
    primals_7 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:0', dtype=torch.int32)
    primals_8 = 128
    primals_9 = 128
    primals_10 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:0', dtype=torch.int32)
    primals_11 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:0', dtype=torch.int32)
    primals_12 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:0', dtype=torch.int32)
    primals_13 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:0', dtype=torch.int32)
    primals_14 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:0', dtype=torch.int32)
    primals_15 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:0', dtype=torch.int32)
    primals_16 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:0', dtype=torch.int32)
    fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16])
    return print_performance(fn, times=times, repeat=repeat)
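# Running this file directly invokes compiled_module_main below, which times
# `call` on the rand_strided sample inputs above (all dynamic sizes fixed to 128
# here) and prints the measured latency via print_performance.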


if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
progress/SpecForge/cache/compiled_kernels/32/8d96bbe05a966b7e7756831f09a79e31bf46fad0952af86f36d75557fc1735e8.best_config
ADDED
@@ -0,0 +1 @@
{"XBLOCK": 512, "num_warps": 4, "num_stages": 1, "configs_hash": "7cced77f371acaa5aa7d90332a90e0c907727cfefb71d9cc9d997c24557fc44f", "found_by_coordesc": false, "time_taken_ms": 13, "triton_cache_hash": "BGHEC74L2RGBNBI3A4UJOTHXFUUKS4KY3KJKVN65FHLWR47O6USQ"}
progress/SpecForge/cache/compiled_kernels/32/c32pbcuz72bjfnkzvckfbbzlzuupc5yxl7t47b3qf74mmk5g2d2z.py
ADDED
@@ -0,0 +1,27 @@

import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 65536},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 282624}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
    xnumel = 35328
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = 0.6931471805599453
    tmp2 = tmp0 * tmp1
    tl.store(out_ptr0 + (x0), tmp2, xmask)
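    # This variant is fully specialized: xnumel is hard-coded to 35328 = 32 * 1104,
    # i.e. a compile for one concrete sequence length (presumably s37 = 1104) on
    # device index 3, rather than taking ks0 dynamically as the version above does.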
progress/SpecForge/cache/compiled_kernels/3b/a0a6b043ab548fdf71e72bbdf5daab7f72e9ed11a9ad9f8824a6263bb6bc5081.best_config
ADDED
@@ -0,0 +1 @@
{"XBLOCK": 512, "num_warps": 4, "num_stages": 1, "configs_hash": "7cced77f371acaa5aa7d90332a90e0c907727cfefb71d9cc9d997c24557fc44f", "found_by_coordesc": false, "time_taken_ms": 14, "triton_cache_hash": "DSCNRRQHW6TSEFKL6AMK6FYZWMIHBTRCG2BE5YK5T7Q76TMOZ5HQ"}
progress/SpecForge/cache/compiled_kernels/3b/c3bqw7dk7k6dcdrp3ycrthotye7y6zb26752jl4lwmfgaybpvr6y.py
ADDED
@@ -0,0 +1,27 @@

import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 65536},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 348160}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
    xnumel = 43520
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = 0.6931471805599453
    tmp2 = tmp0 * tmp1
    tl.store(out_ptr0 + (x0), tmp2, xmask)
progress/SpecForge/cache/compiled_kernels/3f/3f6057605b157d44fd56f748226a63975b79198f94871188e73e46cd6c7f8792.best_config
ADDED
@@ -0,0 +1 @@
{"XBLOCK": 128, "num_warps": 8, "num_stages": 1, "configs_hash": "1542f544a12adfb1397c535fa16687cc79c79a22e4c9cd8af0b373891f747e62", "found_by_coordesc": false, "time_taken_ms": 60, "triton_cache_hash": "XRPIXE6422Z3WVFKM6FTH3VU3RBLBAM5QFGQDRDJKHCOAJAWTZHQ"}
progress/SpecForge/cache/compiled_kernels/3f/c3fttv7enp2yvnla3r6jkk4galt2qdpxw577ghvkmmx6zqaqla74.py
ADDED
@@ -0,0 +1,54 @@

import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 524288, 'r0_': 32},
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_2', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': None, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_per_fused_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK: tl.constexpr):
    r0_numel = 32
    R0_BLOCK: tl.constexpr = 32
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
    roffset = r0_offset
    rindex = r0_index
    r0_2 = r0_index
    x5 = xindex
    x1 = xindex // 128
    x0 = (xindex % 128)
    x3 = ((xindex // 128) % ks0)
    x4 = xindex // ks1
    tmp0 = tl.load(in_ptr0 + (x5 + 4096*ks0*r0_2), None)
    tmp1 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last')
    tmp4 = tl.load(in_ptr2 + (x1 + 32*ks0*r0_2), None, eviction_policy='evict_last')
    tmp13 = tl.load(in_ptr3 + (x1), None, eviction_policy='evict_last')
    tmp2 = float("-inf")
    tmp3 = tmp1 == tmp2
    tmp5 = tmp4 - tmp1
    tmp6 = 0.0
    tmp7 = tl.where(tmp3, tmp6, tmp5)
    tmp8 = libdevice.exp2(tmp7)
    tmp9 = tmp0 * tmp8
    tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK])
    tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32)
    tmp14 = 1.0
    tmp15 = tl.where(tmp3, tmp14, tmp13)
    tmp16 = (tmp12 / tmp15)
    tmp17 = tmp16.to(tl.float32)
    tl.store(out_ptr1 + (x0 + 128*x4 + 4096*x3), tmp17, None)
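    # One reading of this reduction: each output element sums 32 partial
    # accumulators (indexed by r0_2), rescaling each by exp2(partial_max -
    # global_max) before dividing by the combined normalizer - the standard merge
    # step for attention computed in pieces. Rows whose global max is -inf fall
    # back to a weight of exp2(0) = 1 and a denominator of 1, avoiding NaNs.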
progress/SpecForge/cache/compiled_kernels/3n/c3nlaqknekmjv2zuxzow4rf42v3gorxnfp6uod3dg3ic5ibp6yp3.py
ADDED
@@ -0,0 +1,715 @@
# AOT ID: ['1_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream

aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
assert_alignment = torch._C._dynamo.guards.assert_alignment
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p


# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl76p6rje3cyrrbyvxjjj7oxbieltfs4p5xqjre35l6wnofhynby.py
# Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
# Source node to ATen node mapping:
#   flex_attention => flex_attention
# Graph fragment:
#   %arg1_1 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:1" = PlaceHolder[target=arg1_1]
#   %arg3_1 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:1" = PlaceHolder[target=arg3_1]
#   %arg5_1 : Tensor "bf16[1, 8, s43, 128][1024*s43, 128, 1024, 1]cuda:1" = PlaceHolder[target=arg5_1]
#   %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:1" = PlaceHolder[target=getitem_1]
#   %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:1" = PlaceHolder[target=buf1]
#   %arg9_1 : Tensor "i32[1, 1, 5][5, 5, 1]cuda:1" = PlaceHolder[target=arg9_1]
#   %arg6_1 : Tensor "i32[1, 1, 5, 5][25, 25, 5, 1]cuda:1" = PlaceHolder[target=arg6_1]
#   %arg10_1 : Tensor "i32[1, 1, 5][5, 5, 1]cuda:1" = PlaceHolder[target=arg10_1]
#   %arg11_1 : Tensor "i32[1, 1, 5, 5][25, 25, 5, 1]cuda:1" = PlaceHolder[target=arg11_1]
#   %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
#   return %getitem
triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties

@triton_heuristics.template(
    num_stages=3,
    num_warps=8,
    triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]},
    inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
)
@triton.jit
def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2):
    PRESCALE_QK: tl.constexpr = False
    ROWS_GUARANTEED_SAFE: tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr = False
    WRITE_DQ: tl.constexpr = True
    OUTPUT_LOGSUMEXP: tl.constexpr = True
    OUTPUT_MAX: tl.constexpr = False
    FLOAT32_PRECISION: tl.constexpr = 'ieee'
    IS_DIVISIBLE: tl.constexpr = False
    SM_SCALE: tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS: tl.constexpr = 4
    HAS_FULL_BLOCKS: tl.constexpr = True
    QK_HEAD_DIM: tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED: tl.constexpr = 128
    V_HEAD_DIM: tl.constexpr = 128
    V_HEAD_DIM_ROUNDED: tl.constexpr = 128
    SAFE_HEAD_DIM: tl.constexpr = True
    USE_TMA: tl.constexpr = False
    BLOCK_M: tl.constexpr = 128
    BLOCK_N: tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE: tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE: tl.constexpr = 128
    INDEX_DTYPE: tl.constexpr = tl.int32
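    # Unlike the compiled variant earlier in this diff, this one uses
    # FLOAT32_PRECISION = 'ieee' (full-precision fp32 matmuls) instead of 'tf32',
    # and takes a third dynamic size ks2 for the V sequence length (s43 above).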
    Q = arg_Q
    K = arg_K
    V = arg_V
    LSE = arg_LSE
    MAX = arg_MAX
    KV_NUM_BLKS = arg_KV_NUM_BLKS
    KV_IDX = arg_KV_IDX
    FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
    FULL_KV_IDX = arg_FULL_KV_IDX

    # Sub notation for this kernel:
    #
    # Q: Query, K: Key, V: Value
    # M: Number of queries, N: Number of keys/values, D: Model dimension
    # QK_HEAD_DIM: The dimension of the query and key embeddings
    # V_HEAD_DIM: The dimension of the value embeddings
    # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
    # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
    #
    # The following FULL_* and PARTIAL_* are defined on the block-sparse mask grid, rather than the thread block grid.
    # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
    # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
    # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
    # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
    #
    # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
    #
    # (Modifiable) Performance tuning options
    # BLOCK_M: The thread block size across the seqlen dim of Q.
    # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.

    # The options below are kernel options that can be applied for certain score_mods,
    # or involve a numerics vs. perf tradeoff.
    # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
    # about 20% more numerical error, but is slightly faster.
    # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
    # is not masked out? If so, we can skip an extra safety check.
    # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
    # contiguous? If so, we don't need to do an indirect jump for every block.

    tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
    tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)

    # Define strides of inputs
    stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1
    stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1
    stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1

    ZQ = 1
    HQ = 32
    Q_LEN = ks0
    ZKV = 1
    KV_LEN = ks1

    MATMUL_PRECISION = Q.dtype.element_ty

    q_start = tl.program_id(0).to(INDEX_DTYPE)
    off_zq = tl.program_id(1).to(INDEX_DTYPE)
    off_hq = tl.program_id(2).to(INDEX_DTYPE)

    # We support two cases for the batch dimension: a) (ZKV == ZQ), where off_zkv = off_zq,
    # and b) (ZKV == 1 and ZQ > 1), where KV is broadcast along the batch dimension and off_zkv = 0.
    off_zkv = off_zq % ZKV
    off_hkv = off_hq // GQA_SHARED_HEADS
    off_g = off_hq % GQA_SHARED_HEADS

    q_offset = off_zq * stride_qz + off_hq * stride_qh
    k_offset = off_zkv * stride_kz + off_hkv * stride_kh
    v_offset = off_zkv * stride_vz + off_hkv * stride_vh

    Q = Q + q_offset
    K = K + k_offset
    V = V + v_offset

    # Setting up the TMA descriptors for Q, K, V
    desc_q = None
    desc_k = None
    desc_v = None

    SPARSE_Z = 1
    SPARSE_HQ = 1

    sparse_idx_z = off_zq % SPARSE_Z
    sparse_idx_hq = off_hq % SPARSE_HQ

    SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)

    stride_kv_num_blks_h = 5
    stride_kv_idx_h = 25
    stride_kv_idx_m = 5
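    # Here the block-sparse metadata lives on a 5x5 block grid (the
    # i32[1, 1, 5, 5] KV_IDX tensors in the graph fragment above), so the per-head
    # strides are 5/25/5 rather than the 1/1/1 of the single-block variant earlier
    # in this diff.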
|
| 184 |
+
|
| 185 |
+
# initialize pointer to m and l
|
| 186 |
+
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
| 187 |
+
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
| 188 |
+
acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
| 189 |
+
|
| 190 |
+
offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 191 |
+
|
| 192 |
+
# KV_IDX and KV_NUM_BLKS are always contiguous.
|
| 193 |
+
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
|
| 194 |
+
sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
|
| 195 |
+
sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
|
| 196 |
+
offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 197 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 198 |
+
q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
|
| 199 |
+
|
| 200 |
+
# ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 201 |
+
# We don't know anything "special" about these blocks, so we need to apply
|
| 202 |
+
# both score_mod and mask_mod to it
|
| 203 |
+
kv_indices = KV_IDX + sparse_kv_idx_offset
|
| 204 |
+
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
|
| 205 |
+
kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
|
| 206 |
+
block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))


    # K and V pointers will be passed directly to forward_inner

    offs_n = kv_start + tl.arange(0, BLOCK_N)


    acc, l_i, m_i = forward_inner(
        arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
        q, K, V,
        desc_k, desc_v, Q_LEN, KV_LEN,
        acc, l_i, m_i,
        off_zq, off_hq, offs_m[:, None], offs_n[None, :],
        kv_start,
        kv_indices, kv_num_blocks,
        0, block_n_end,
        MATMUL_PRECISION,
        stride_kk, stride_kn, stride_vn, stride_vk,
        IS_FULL_BLOCKS=False,
    )

    # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # We know these blocks are guaranteed to be "full", so we don't need to
    # apply mask_mod to them - only score_mod
    if HAS_FULL_BLOCKS:
        # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
        kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
        kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE  # first kv block we're loading
        kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
        block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
        # K and V pointers will be passed directly to forward_inner
        offs_n = kv_start + tl.arange(0, BLOCK_N)

        acc, l_i, m_i = forward_inner(
            arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
            q, K, V,
            desc_k, desc_v, Q_LEN, KV_LEN,
            acc, l_i, m_i,
            off_zq, off_hq, offs_m[:, None], offs_n[None, :],
            kv_start,
            kv_indices, kv_num_blocks,
            0, block_n_end,
            MATMUL_PRECISION,
            stride_kk, stride_kn, stride_vn, stride_vk,
            IS_FULL_BLOCKS=True,
        )


    # [Note] Handle fully masked out rows:
    # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
    # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
    l_i = tl.where(l_i == 0.0, 1, l_i)
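    # [Editor's note] Worked example of the fix-up above: a fully masked row ends
    # the loop with l_i == 0.0, so acc / l_i and log2(l_i) below would produce
    # NaN noise; substituting l_i = 1.0 yields acc / 1.0 == 0.0 and
    # log2(1.0) == 0.0, which is the "lse/out = 0.0" behaviour the note refers to.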

    acc = acc / l_i[:, None]
    idx_zq = tl.program_id(1).to(INDEX_DTYPE)
    idx_hq = tl.program_id(2).to(INDEX_DTYPE)
    idx_m = offs_m[:, None].to(INDEX_DTYPE)
    idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)

    mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)

    tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
    xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0
    tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask)

    if OUTPUT_LOGSUMEXP:
        off_hz = off_zq * HQ + off_hq
        l_ptrs = LSE + off_hz * Q_LEN + offs_m
        lse = m_i + tl.math.log2(l_i)
        if IS_DIVISIBLE:
            tl.store(l_ptrs, lse)
        else:
            tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)

    if OUTPUT_MAX:
        off_hz = off_zq * HQ + off_hq
        max_ptrs = MAX + off_hz * Q_LEN + offs_m
        if IS_DIVISIBLE:
            tl.store(max_ptrs, m_i)
        else:
            tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)


# Utility triton funcs
@triton.jit
def get_offset_for_next_block(
    loop_iter, col_indices, total_blocks,
    SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr
):
    if BLOCKS_ARE_CONTIGUOUS:
        return BLOCK
    cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
    cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
    next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
    needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
    jump_to_block = (next_block - cur_block) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
    offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
    return offset
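# [Editor's note] Worked example of the jump arithmetic above (hypothetical
# numbers): with SPARSE_BLOCK = 128, BLOCK = 64, SPARSE_BLOCK_MULTIPLE = 2,
# cur_block = 1 and next_block = 4, the step that closes sparse block 1 jumps
# by (4 - 1) * 128 - 1 * 64 = 320 elements, landing exactly at the start of
# sparse block 4; every non-boundary step just advances by BLOCK = 64.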

@triton.jit
def get_bounded_indices(indices, max_len=None):
    return indices % max_len if max_len is not None else indices

@triton.jit
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
    if IS_DIVISIBLE and SAFE_HEAD_DIM:
        return tl.load(block_ptr)
    elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
    elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
    else:
        return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")

@triton.jit
def load_checked_2d(
    ptr,
    offs_m,
    offs_n,
    stride_m,
    stride_n,
    IS_DIVISIBLE_M: tl.constexpr,
    IS_DIVISIBLE_N: tl.constexpr,
    M_LEN: tl.constexpr,
    N_LEN: tl.constexpr,
):
    # Calculate final pointer if strides are provided
    if stride_m is not None and stride_n is not None:
        ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n

    # Handle all masking cases
    if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
    elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
    elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
    else:  # Both divisible
        return tl.load(ptr)
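# [Editor's note] Usage note (hedged): the kernel above calls this as
# load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE,
# SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM), so with IS_DIVISIBLE = False and
# SAFE_HEAD_DIM = True only the seqlen axis is masked: rows with
# offs_m >= Q_LEN read back as 0.0 instead of faulting past the tensor.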


# Common Imports
@triton.jit
def forward_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
    q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
    # accumulated values
    acc, l_i, m_i,
    # Offsets
    off_z, off_h, offs_m, offs_n,
    # Offsets needed for TMA loads
    kv_start,
    kv_offset,
    MATMUL_PRECISION, RCP_LN2,
    # Strides for K and V
    stride_kk, stride_kn, stride_vn, stride_vk,
    IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,

):
    # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'ieee'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    USE_TMA : tl.constexpr = False
    BLOCK_M : tl.constexpr = 128
    BLOCK_N : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32


    # -- load k --
    # NB: reversed order since K is transposed
    kv_base_offset = kv_start + kv_offset

    # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
    k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)

    k = tl.trans(k)
    # -- compute qk ---
    qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION)  # TODO: use cuda matmul when q_len <= 2.
    if not PRESCALE_QK:
        qk *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    # If this is the last block of a non-divisible seqlen, we still load a full
    # [BLOCK_M, BLOCK_N] tile, which is larger than the actual number of valid
    # elements. To avoid out-of-bounds memory accesses, we mask out the
    # elements that fall outside Q_LEN & KV_LEN.
    m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
    n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)

    tmp0 = (qk)
    post_mod_scores = tmp0


    if CHECK_BLOCK_BOUNDARY:
        # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
        post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        tmp1 = (m)
        tmp2 = tl.full([1], 0, tl.int32)
        tmp3 = tmp1 < tmp2
        tmp4 = (n)
        tmp5 = tmp4 <= tmp1
        tmp6 = tmp3 & tmp5
        tmp7 = tmp1 >= tmp2
        tmp8 = tmp4 < tmp2
        tmp9 = tmp7 & tmp8
        tmp10 = tmp8 == 0
        tmp11 = tmp7 & tmp10
        tmp12 = tmp1 - tmp2
        tmp13 = tl.full([1], 16, tl.int32)
        tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
        tmp15 = tmp4 - tmp2
        tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
        tmp17 = tmp14 == tmp16
        tmp18 = tmp11 & tmp17
        tmp19 = tmp9 | tmp18
        tmp20 = tmp6 | tmp19
        mask_mod_output = tmp20


        if CHECK_BLOCK_BOUNDARY:
            mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
        # apply mask for partially unmasked blocks
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))

    if not PRESCALE_QK:
        post_mod_scores *= RCP_LN2
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # -- compute scaling constant ---
    m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
    if not ROWS_GUARANTEED_SAFE:
        masked_out_rows = (m_ij == float("-inf"))
        m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
    else:
        m_ij_masked = m_ij

    alpha = tl.math.exp2(m_i - m_ij_masked)
    p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
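    # [Editor's note] Why exp2 and RCP_LN2: exp(x) == exp2(x * log2(e)), and
    # RCP_LN2 = 1.44269504 is log2(e). The scores were pre-multiplied by RCP_LN2
    # above, so exp2 here computes the ordinary softmax numerator, e.g.
    # exp2(1.0 * 1.44269504) == 2.71828... == exp(1.0), while mapping to the
    # cheaper hardware exp2 instruction.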

    # NB: l_i update is pulled up here since it's a bit faster
    # NB: For headdim=256, it's faster to move it back down to after m_i =
    # m_ij
    l_i = l_i * alpha + tl.sum(p, 1)
    # # -- scale and update acc --
    acc = acc * alpha[:, None]
    # Calculate offsets for V loading - reuse kv_base_offset from K loading
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
    v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
    acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)

    # -- update m_i
    m_i = m_ij

    return acc, l_i, m_i
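# [Editor's note] Invariant maintained by forward_block_mn (stated for clarity,
# not part of the generated file): after each block, m_i is the running row max
# of all scores seen so far, l_i is the running sum of exponent weights rescaled
# to that max, and acc is the V-weighted sum under the same rescaling, so
# acc / l_i at the end equals softmax(scores) @ V computed in a single pass.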

@triton.jit
def forward_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
    q, K, V,
    desc_k, desc_v, Q_LEN, KV_LEN,
    # accumulated values
    acc, l_i, m_i,
    # Offsets used as inputs to score_mod & mask_mod
    # of size [BLOCK_M, BLOCK_N] or scalar.
    off_z, off_h, offs_m, offs_n,
    # Offsets needed for TMA loads
    kv_start,
    # blocksparse data
    kv_indices, kv_num_blocks,
    # start kv and end kv block
    block_n_start, block_n_end,
    MATMUL_PRECISION,
    # Strides for K and V
    stride_kk, stride_kn, stride_vn, stride_vk,
    IS_FULL_BLOCKS,
):
    # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'ieee'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    USE_TMA : tl.constexpr = False
    BLOCK_M : tl.constexpr = 128
    BLOCK_N : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32


    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
    RCP_LN2: tl.constexpr = 1.44269504

    if PRESCALE_QK:
        q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)

    kv_offset = 0

    # loop over k, v and update accumulator until block_n_end
    for start_n in range(block_n_start, block_n_end):
        # Here IS_DIVISIBLE acts as the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
        if IS_DIVISIBLE:
            acc, l_i, m_i = forward_block_mn(
                arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
                q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
                # accumulated values
                acc, l_i, m_i,
                # Offsets
                off_z, off_h, offs_m, offs_n,
                # Offsets needed for TMA loads
                kv_start,
                kv_offset,
                MATMUL_PRECISION, RCP_LN2,
                # Strides for K and V
                stride_kk, stride_kn, stride_vn, stride_vk,
                IS_FULL_BLOCKS,
            )
        else:
            # Benchmarks show that even when we apply mod & mask to every block for a
            # non-divisible seqlen, it's on par with or slightly faster than applying
            # them only to the last block in fwd. We choose a different strategy for
            # bwd, where we only apply mod & mask to the last block, because there it
            # is much faster.
            acc, l_i, m_i = forward_block_mn(
                arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
                q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
                # accumulated values
                acc, l_i, m_i,
                # Offsets
                off_z, off_h, offs_m, offs_n,
                # Offsets needed for TMA loads
                kv_start,
                kv_offset,
                MATMUL_PRECISION, RCP_LN2,
                # Strides for K and V
                stride_kk, stride_kn, stride_vn, stride_vk,
                IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
            )



        offset = get_offset_for_next_block(
            start_n, kv_indices, kv_num_blocks,
            SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
        )

        offs_n = offs_n + offset
        kv_offset += offset


    return acc, l_i, m_i
''', device_str='cuda')
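# [Editor's note] The forward_inner/forward_block_mn pair above implements a
# streaming ("online") softmax. A minimal NumPy sketch of the same m_i/l_i/alpha
# recurrence, in natural log rather than the kernel's base-2 arithmetic; this
# helper and its name are illustrative additions, not generated code:
def _reference_online_lse(score_blocks):
    import numpy as np
    m, l = float("-inf"), 0.0
    for s in score_blocks:  # each s: 1-D array of scores for one KV block
        m_new = max(m, float(s.max()))  # running row max (m_ij)
        # rescale the old sum to the new max, then add this block (alpha, l_i)
        l = l * np.exp(m - m_new) + float(np.exp(s - m_new).sum())
        m = m_new  # m_i = m_ij
    # equals np.log(np.exp(np.concatenate(score_blocks)).sum())
    return m + np.log(l)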


# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ft/cftmlennkcgyn4ynz7zxqohr2jlirziu3mfte3b4eg5y2466jcwm.py
# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul]
# Source node to ATen node mapping:
#   lse_scaled => mul_9
# Graph fragment:
#   %buf3 : Tensor = PlaceHolder[target=buf3]
#   %mul_9 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {})
#   return %mul_9
triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 32768},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x2 = xindex
    x0 = (xindex % ks0)
    x1 = triton_helpers.div_floor_integer(xindex, ks0)
    tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last')
    tmp1 = 0.6931471805599453
    tmp2 = tmp0 * tmp1
    tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask)
''', device_str='cuda')
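# [Editor's note] Why lse is multiplied by 0.6931471805599453 (= ln 2): the
# attention kernel accumulates in base 2 and stores lse2 = m_i + log2(l_i), so
# scaling by ln(2) recovers the natural-log logsumexp. A self-contained check of
# that identity (illustrative helper, not generated code):
def _check_lse_base_conversion():
    import math
    import numpy as np
    x = np.array([0.3, -1.2, 2.5])
    lse2 = np.log2(np.exp2(x * 1.4426950408889634).sum())  # kernel-side base-2 form
    assert abs(lse2 * 0.6931471805599453 - math.log(np.exp(x).sum())) < 1e-9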


async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1 = args
        args.clear()
        s50 = arg0_1
        s0 = arg2_1
        s43 = arg4_1
        s37 = arg7_1
        s71 = arg8_1
        assert_size_stride(arg1_1, (1, 32, s37, 128), (4096*s37, 128, 4096, 1))
        assert_size_stride(arg3_1, (1, 8, s0, 128), (1024*s0, 128, 1024, 1))
        assert_size_stride(arg5_1, (1, 8, s43, 128), (1024*s43, 128, 1024, 1))
        assert_size_stride(arg6_1, (1, 1, 5, 5), (25, 25, 5, 1))
        assert_size_stride(arg9_1, (1, 1, 5), (5, 5, 1))
        assert_size_stride(arg10_1, (1, 1, 5), (5, 5, 1))
        assert_size_stride(arg11_1, (1, 1, 5, 5), (25, 25, 5, 1))
        assert_size_stride(arg12_1, (1, 1, 5), (5, 5, 1))
        assert_size_stride(arg13_1, (1, 1, 5, 5), (25, 25, 5, 1))
        assert_size_stride(arg14_1, (1, 1, 5), (5, 5, 1))
        assert_size_stride(arg15_1, (1, 1, 5, 5), (25, 25, 5, 1))
        with torch.cuda._DeviceGuard(1):
            torch.cuda.set_device(1)
            buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32)
            buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32)
            buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16)
            # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
            stream1 = get_raw_stream(1)
            triton_tem_fused_0.run(arg1_1, arg3_1, arg5_1, buf0, buf1, arg9_1, arg6_1, arg10_1, arg11_1, buf2, s37, s0, s43, (127 + s37) // 128, 1, 32, stream=stream1)
            del arg10_1
            del arg11_1
            del arg1_1
            del arg3_1
            del arg5_1
            del arg6_1
            del arg9_1
            del buf1
            buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32)
            # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul]
            triton_poi_fused_mul_1_xnumel = 32*s37
            stream1 = get_raw_stream(1)
            triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream1)
            del buf0
        return (buf2, buf5, )

runner = Runner(partitions=[])
call = runner.call
recursively_apply_fns = runner.recursively_apply_fns


def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    arg0_1 = 528
    arg1_1 = rand_strided((1, 32, 528, 128), (2162688, 128, 4096, 1), device='cuda:1', dtype=torch.bfloat16)
    arg2_1 = 528
    arg3_1 = rand_strided((1, 8, 528, 128), (540672, 128, 1024, 1), device='cuda:1', dtype=torch.bfloat16)
    arg4_1 = 528
    arg5_1 = rand_strided((1, 8, 528, 128), (540672, 128, 1024, 1), device='cuda:1', dtype=torch.bfloat16)
    arg6_1 = rand_strided((1, 1, 5, 5), (25, 25, 5, 1), device='cuda:1', dtype=torch.int32)
    arg7_1 = 528
    arg8_1 = 528
    arg9_1 = rand_strided((1, 1, 5), (5, 5, 1), device='cuda:1', dtype=torch.int32)
    arg10_1 = rand_strided((1, 1, 5), (5, 5, 1), device='cuda:1', dtype=torch.int32)
    arg11_1 = rand_strided((1, 1, 5, 5), (25, 25, 5, 1), device='cuda:1', dtype=torch.int32)
    arg12_1 = rand_strided((1, 1, 5), (5, 5, 1), device='cuda:1', dtype=torch.int32)
    arg13_1 = rand_strided((1, 1, 5, 5), (25, 25, 5, 1), device='cuda:1', dtype=torch.int32)
    arg14_1 = rand_strided((1, 1, 5), (5, 5, 1), device='cuda:1', dtype=torch.int32)
    arg15_1 = rand_strided((1, 1, 5, 5), (25, 25, 5, 1), device='cuda:1', dtype=torch.int32)
    fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1])
    return print_performance(fn, times=times, repeat=repeat)


if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
progress/SpecForge/cache/compiled_kernels/3q/c3qbvcsx2w7qss2v3eocuadgz6t35joo33bflzqkxzzj747zcjpk.py ADDED
@@ -0,0 +1,51 @@

import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.reduction(
    size_hints={'x': 65536, 'r0_': 128},
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    r0_numel = 128
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = (xindex % ks0)
    x1 = triton_helpers.div_floor_integer(xindex, ks0)
    _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    x3 = xindex
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_2 = r0_index
        tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp2 = tmp0 * tmp1
        tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
        tmp5 = _tmp4 + tmp3
        _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4)
    tmp4 = tl.sum(_tmp4, 1)[:, None]
    tmp7 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last')
    tmp6 = tmp4.to(tl.float32)
    tmp8 = 0.6931471805599453
    tmp9 = tmp7 * tmp8
    tmp10 = 1.4426950408889634
    tmp11 = tmp9 * tmp10
    tmp12 = tmp6 - tmp11
    tl.store(out_ptr1 + (x3), tmp12, xmask)
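# [Editor's note] What this reduction computes (analysis added for clarity, not
# generated text): out[x] = sum_d(in0 * in1) - in2[x]. The tmp9/tmp10 pair
# multiplies in_ptr2 by ln(2) * log2(e), which is mathematically the identity,
# so the kernel effectively fuses a 128-wide dot product with subtraction of the
# logsumexp term held in in_ptr2.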
progress/SpecForge/cache/compiled_kernels/3q/fc5920467dd1501963c976e2b895fc37747fdebfa098fff912209055f3a31828.best_config ADDED
@@ -0,0 +1 @@
{"XBLOCK": 64, "R0_BLOCK": 64, "num_warps": 16, "num_stages": 1, "configs_hash": "2685c2d349c32243d4ee216505dfdf1e257d04d8316595ed69d4ca3499146788", "found_by_coordesc": false, "time_taken_ms": 53, "triton_cache_hash": "GBIQTIXLLLI56EMJONBW74RZJ42E6PTSU5N7LA23N4VBEBKK3HNQ"}
progress/SpecForge/cache/compiled_kernels/3r/c3rkwwyedldrjz6sidtx5huqcsdgpdpu4xndmm6h4e4boo6cbg2w.py ADDED
@@ -0,0 +1,702 @@
# AOT ID: ['0_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream

aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
assert_alignment = torch._C._dynamo.guards.assert_alignment
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p


# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/nj/cnjtse3xftpnmqvwojj6g7ajl3r3hvxbz3sgyaaznnrxcs7gzj2e.py
# Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
# Source node to ATen node mapping:
#   flex_attention => flex_attention
# Graph fragment:
#   %arg0_1 : Tensor "bf16[1, 32, 976, 128][3997696, 128, 4096, 1]cuda:7" = PlaceHolder[target=arg0_1]
#   %arg1_1 : Tensor "bf16[1, 8, 976, 128][999424, 128, 1024, 1]cuda:7" = PlaceHolder[target=arg1_1]
#   %arg2_1 : Tensor "bf16[1, 8, 976, 128][999424, 128, 1024, 1]cuda:7" = PlaceHolder[target=arg2_1]
#   %getitem_1 : Tensor "f32[1, 32, 976][31232, 976, 1]cuda:7" = PlaceHolder[target=getitem_1]
#   %buf1 : Tensor "f32[1, 32, 976][31232, 976, 1]cuda:7" = PlaceHolder[target=buf1]
#   %arg3_1 : Tensor "i32[1, 1, 8][8, 8, 1]cuda:7" = PlaceHolder[target=arg3_1]
#   %arg4_1 : Tensor "i32[1, 1, 8, 8][64, 64, 8, 1]cuda:7" = PlaceHolder[target=arg4_1]
#   %arg5_1 : Tensor "i32[1, 1, 8][8, 8, 1]cuda:7" = PlaceHolder[target=arg5_1]
#   %arg6_1 : Tensor "i32[1, 1, 8, 8][64, 64, 8, 1]cuda:7" = PlaceHolder[target=arg6_1]
#   %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg0_1, %arg1_1, %arg2_1, %sdpa_score0, (976, 976, %arg3_1, %arg4_1, %arg5_1, %arg6_1, %arg7_1, %arg8_1, %arg9_1, %arg10_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
#   return %getitem
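# [Editor's note] Hedged sketch of the eager-level call this higher-order op
# corresponds to. The real SpecForge mask_mod is the speculative-decoding mask
# captured above as %sdpa_mask0; a plain causal mask_mod stands in for it here,
# and every name below is illustrative, not part of the generated module:
#
#     from torch.nn.attention.flex_attention import flex_attention, create_block_mask
#
#     def _causal(b, h, q_idx, kv_idx):
#         return q_idx >= kv_idx
#
#     _bm = create_block_mask(_causal, 1, 1, 976, 976, device="cuda")
#     _q = torch.randn(1, 32, 976, 128, device="cuda", dtype=torch.bfloat16)
#     _k = torch.randn(1, 8, 976, 128, device="cuda", dtype=torch.bfloat16)
#     _v = torch.randn(1, 8, 976, 128, device="cuda", dtype=torch.bfloat16)
#     _out, _lse = torch.compile(flex_attention)(
#         _q, _k, _v, block_mask=_bm, enable_gqa=True, return_lse=True)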
triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties

@triton_heuristics.template(

    num_stages=3,
    num_warps=8,
    triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]},
    inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},

)
@triton.jit
def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    USE_TMA : tl.constexpr = False
    BLOCK_M : tl.constexpr = 128
    BLOCK_N : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32
    Q = arg_Q
    K = arg_K
    V = arg_V
    LSE = arg_LSE
    MAX = arg_MAX
    KV_NUM_BLKS = arg_KV_NUM_BLKS
    KV_IDX = arg_KV_IDX
    FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
    FULL_KV_IDX = arg_FULL_KV_IDX

    # Sub notation for this kernel:
    #
    # Q: Query, K: Key, V: Value
    # M: Number of queries, N: Number of keys/values, D: Model dimension
    # QK_HEAD_DIM: The dimension of the query and key embeddings
    # V_HEAD_DIM: The dimension of the value embeddings
    # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
    # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
    #
    # The following FULL_* and PARTIAL_* are defined in the block sparse mask grid, rather than the thread block grid.
    # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
    # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
    # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
    # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
    #
    # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
    #
    # (Modifiable) Performance tuning options
    # BLOCK_M: The thread block size across the seqlen dim of Q.
    # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.

    # The below are kernel options that can be applied for certain score_mods,
    # or involve a numerics vs. perf tradeoff
    # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
    # about 20% more numerical error, but slightly faster.
    # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
    # is not masked out? If so, we can skip an extra safety check
    # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
    # contiguous? If so, we don't need to do an indirect jump for every block
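    # [Editor's note] Hypothetical illustration of this metadata layout: with
    # Q_LEN = KV_LEN = 976 and SPARSE_*_BLOCK_SIZE = 128 there are 8 q/kv blocks,
    # matching the i32[1, 1, 8] and i32[1, 1, 8, 8] mask tensors in the graph
    # fragment above. For a causal layout, q block 3 would have
    # FULL_KV_NUM_BLKS[0, 0, 3] == 3 with FULL_KV_IDX[0, 0, 3, :3] == [0, 1, 2]
    # (fully unmasked blocks) plus KV_NUM_BLKS[0, 0, 3] == 1 partial block on
    # the diagonal; the actual SpecForge mask differs.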
    tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
    tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)

    # Define strides of inputs
    stride_qz, stride_qh, stride_qm, stride_qk = 3997696, 128, 4096, 1
    stride_kz, stride_kh, stride_kn, stride_kk = 999424, 128, 1024, 1
    stride_vz, stride_vh, stride_vn, stride_vk = 999424, 128, 1024, 1

    ZQ = 1
    HQ = 32
    Q_LEN = 976
    ZKV = 1
    KV_LEN = 976

    MATMUL_PRECISION = Q.dtype.element_ty

    q_start = tl.program_id(0).to(INDEX_DTYPE)
    off_zq = tl.program_id(1).to(INDEX_DTYPE)
    off_hq = tl.program_id(2).to(INDEX_DTYPE)

    # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
    # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
    off_zkv = off_zq % ZKV
    off_hkv = off_hq // GQA_SHARED_HEADS
    off_g = off_hq % GQA_SHARED_HEADS

    q_offset = off_zq * stride_qz + off_hq * stride_qh
    k_offset = off_zkv * stride_kz + off_hkv * stride_kh
    v_offset = off_zkv * stride_vz + off_hkv * stride_vh

    Q = Q + q_offset
    K = K + k_offset
    V = V + v_offset

    # Setting up the TMA descriptors for Q, K, V
    desc_q = None
    desc_k = None
    desc_v = None

    SPARSE_Z = 1
    SPARSE_HQ = 1

    sparse_idx_z = off_zq % SPARSE_Z
    sparse_idx_hq = off_hq % SPARSE_HQ

    SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)

    stride_kv_num_blks_h = 8
    stride_kv_idx_h = 64
    stride_kv_idx_m = 8

    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)

    offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)

    # KV_IDX and KV_NUM_BLKS are always contiguous.
    sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
    sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
    sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m  # noqa: B950
    offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)

    # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # We don't know anything "special" about these blocks, so we need to apply
    # both score_mod and mask_mod to them
    kv_indices = KV_IDX + sparse_kv_idx_offset
    kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE  # first kv block we're loading
    kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
    block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))


    # K and V pointers will be passed directly to forward_inner

    offs_n = kv_start + tl.arange(0, BLOCK_N)


    acc, l_i, m_i = forward_inner(
        arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0,
        q, K, V,
        desc_k, desc_v, Q_LEN, KV_LEN,
        acc, l_i, m_i,
        off_zq, off_hq, offs_m[:, None], offs_n[None, :],
        kv_start,
        kv_indices, kv_num_blocks,
        0, block_n_end,
        MATMUL_PRECISION,
        stride_kk, stride_kn, stride_vn, stride_vk,
        IS_FULL_BLOCKS=False,
    )

    # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # We know these blocks are guaranteed to be "full", so we don't need to
    # apply mask_mod to them - only score_mod
    if HAS_FULL_BLOCKS:
        # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
        kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
        kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE  # first kv block we're loading
        kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
        block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
        # K and V pointers will be passed directly to forward_inner
        offs_n = kv_start + tl.arange(0, BLOCK_N)

        acc, l_i, m_i = forward_inner(
            arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0,
            q, K, V,
            desc_k, desc_v, Q_LEN, KV_LEN,
            acc, l_i, m_i,
            off_zq, off_hq, offs_m[:, None], offs_n[None, :],
            kv_start,
            kv_indices, kv_num_blocks,
            0, block_n_end,
            MATMUL_PRECISION,
            stride_kk, stride_kn, stride_vn, stride_vk,
            IS_FULL_BLOCKS=True,
        )


    # [Note] Handle fully masked out rows:
    # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
    # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
    l_i = tl.where(l_i == 0.0, 1, l_i)

    acc = acc / l_i[:, None]
    idx_zq = tl.program_id(1).to(INDEX_DTYPE)
    idx_hq = tl.program_id(2).to(INDEX_DTYPE)
    idx_m = offs_m[:, None].to(INDEX_DTYPE)
    idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)

    mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)

    tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
    xindex = idx_d + 128*idx_m + 124928*idx_hq + 3997696*idx_zq
    tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask)

    if OUTPUT_LOGSUMEXP:
        off_hz = off_zq * HQ + off_hq
        l_ptrs = LSE + off_hz * Q_LEN + offs_m
        lse = m_i + tl.math.log2(l_i)
        if IS_DIVISIBLE:
            tl.store(l_ptrs, lse)
        else:
            tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)

    if OUTPUT_MAX:
        off_hz = off_zq * HQ + off_hq
        max_ptrs = MAX + off_hz * Q_LEN + offs_m
        if IS_DIVISIBLE:
            tl.store(max_ptrs, m_i)
        else:
            tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)


# Utility triton funcs
@triton.jit
def get_offset_for_next_block(
    loop_iter, col_indices, total_blocks,
    SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr
):
    if BLOCKS_ARE_CONTIGUOUS:
        return BLOCK
    cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
    cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
    next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
    needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
    jump_to_block = (next_block - cur_block) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
    offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
    return offset

@triton.jit
def get_bounded_indices(indices, max_len=None):
    return indices % max_len if max_len is not None else indices

@triton.jit
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
    if IS_DIVISIBLE and SAFE_HEAD_DIM:
        return tl.load(block_ptr)
    elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
    elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
    else:
        return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")

@triton.jit
def load_checked_2d(
    ptr,
    offs_m,
    offs_n,
    stride_m,
    stride_n,
    IS_DIVISIBLE_M: tl.constexpr,
    IS_DIVISIBLE_N: tl.constexpr,
    M_LEN: tl.constexpr,
    N_LEN: tl.constexpr,
):
    # Calculate final pointer if strides are provided
    if stride_m is not None and stride_n is not None:
        ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n

    # Handle all masking cases
    if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
    elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
    elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
    else:  # Both divisible
        return tl.load(ptr)


# Common Imports
@triton.jit
def forward_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0,
    q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
    # accumulated values
    acc, l_i, m_i,
    # Offsets
    off_z, off_h, offs_m, offs_n,
    # Offsets needed for TMA loads
    kv_start,
    kv_offset,
    MATMUL_PRECISION, RCP_LN2,
    # Strides for K and V
    stride_kk, stride_kn, stride_vn, stride_vk,
    IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,

):
    # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    USE_TMA : tl.constexpr = False
    BLOCK_M : tl.constexpr = 128
    BLOCK_N : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32


    # -- load k --
    # NB: reversed order since K is transposed
    kv_base_offset = kv_start + kv_offset

    # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
    k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)

    k = tl.trans(k)
    # -- compute qk ---
    qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION)  # TODO: use cuda matmul when q_len <= 2.
    if not PRESCALE_QK:
        qk *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    # If this is the last block of a non-divisible seqlen, we still load a full
    # [BLOCK_M, BLOCK_N] tile, which is larger than the actual number of valid
    # elements. To avoid out-of-bounds memory accesses, we mask out the
    # elements that fall outside Q_LEN & KV_LEN.
    m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
    n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)

    tmp0 = (qk)
    post_mod_scores = tmp0


    if CHECK_BLOCK_BOUNDARY:
        # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
        post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        tmp1 = (m)
        tmp2 = tl.full([1], 0, tl.int32)
        tmp3 = tmp1 < tmp2
        tmp4 = (n)
        tmp5 = tmp4 <= tmp1
        tmp6 = tmp3 & tmp5
        tmp7 = tmp1 >= tmp2
        tmp8 = tmp4 < tmp2
        tmp9 = tmp7 & tmp8
        tmp10 = tmp8 == 0
        tmp11 = tmp7 & tmp10
        tmp12 = tmp1 - tmp2
        tmp13 = tl.full([1], 16, tl.int32)
        tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
        tmp15 = tmp4 - tmp2
        tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
        tmp17 = tmp14 == tmp16
        tmp18 = tmp11 & tmp17
        tmp19 = tmp9 | tmp18
        tmp20 = tmp6 | tmp19
        mask_mod_output = tmp20


        if CHECK_BLOCK_BOUNDARY:
            mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
        # apply mask for partially unmasked blocks
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))

    if not PRESCALE_QK:
        post_mod_scores *= RCP_LN2
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # -- compute scaling constant ---
    m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
    if not ROWS_GUARANTEED_SAFE:
        masked_out_rows = (m_ij == float("-inf"))
        m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
    else:
        m_ij_masked = m_ij

    alpha = tl.math.exp2(m_i - m_ij_masked)
    p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])

    # NB: l_i update is pulled up here since it's a bit faster
    # NB: For headdim=256, it's faster to move it back down to after m_i =
    # m_ij
    l_i = l_i * alpha + tl.sum(p, 1)
    # # -- scale and update acc --
    acc = acc * alpha[:, None]
    # Calculate offsets for V loading - reuse kv_base_offset from K loading
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
    v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
    acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)

    # -- update m_i
    m_i = m_ij

    return acc, l_i, m_i
|
| 480 |
+
|
| 481 |
+
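# Worked example of the online-softmax update above for one row (illustrative
# only), assuming running max m_i = 0.0 and running sum l_i = 1.0, with scores
# already in log2 space:
#   block scores [2.0, 1.0]  ->  m_ij = max(0.0, 2.0) = 2.0
#   alpha = 2^(0.0 - 2.0) = 0.25             # rescales the old l_i and acc
#   p = [2^(2-2), 2^(1-2)] = [1.0, 0.5]
#   l_i = 1.0 * 0.25 + (1.0 + 0.5) = 1.75
# Rescaling acc by the same alpha keeps acc / l_i equal to the softmax-weighted
# average over every block processed so far.
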
@triton.jit
def forward_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0,
    q, K, V,
    desc_k, desc_v, Q_LEN, KV_LEN,
    # accumulated values
    acc, l_i, m_i,
    # Offsets used as inputs to score_mod & mask_mod
    # of size [BLOCK_M, BLOCK_N] or scalar.
    off_z, off_h, offs_m, offs_n,
    # Offsets needed for TMA loads
    kv_start,
    # blocksparse data
    kv_indices, kv_num_blocks,
    # start kv and end kv block
    block_n_start, block_n_end,
    MATMUL_PRECISION,
    # Strides for K and V
    stride_kk, stride_kn, stride_vn, stride_vk,
    IS_FULL_BLOCKS,
):
    # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    USE_TMA : tl.constexpr = False
    BLOCK_M : tl.constexpr = 128
    BLOCK_N : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
    RCP_LN2: tl.constexpr = 1.44269504

    if PRESCALE_QK:
        q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)

    kv_offset = 0

    # loop over k, v and update accumulator until block_n_end
    for start_n in range(block_n_start, block_n_end):
        # Here IS_DIVISIBLE acts as the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
        if IS_DIVISIBLE:
            acc, l_i, m_i = forward_block_mn(
                arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0,
                q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
                # accumulated values
                acc, l_i, m_i,
                # Offsets
                off_z, off_h, offs_m, offs_n,
                # Offsets needed for TMA loads
                kv_start,
                kv_offset,
                MATMUL_PRECISION, RCP_LN2,
                # Strides for K and V
                stride_kk, stride_kn, stride_vn, stride_vk,
                IS_FULL_BLOCKS,
            )
        else:
            # Benchmarks show that even when we apply mod & mask to every block for a
            # non-divisible seqlen, it's on par with or slightly faster than only
            # applying them to the last block in fwd. However, we choose a different
            # strategy for bwd, where we only apply mod & mask to the last block,
            # because there it's a lot faster.
            acc, l_i, m_i = forward_block_mn(
                arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0,
                q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
                # accumulated values
                acc, l_i, m_i,
                # Offsets
                off_z, off_h, offs_m, offs_n,
                # Offsets needed for TMA loads
                kv_start,
                kv_offset,
                MATMUL_PRECISION, RCP_LN2,
                # Strides for K and V
                stride_kk, stride_kn, stride_vn, stride_vk,
                IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
            )

        offset = get_offset_for_next_block(
            start_n, kv_indices, kv_num_blocks,
            SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
        )

        offs_n = offs_n + offset
        kv_offset += offset

    return acc, l_i, m_i
''', device_str='cuda')
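
# A pure-Python reference of the block-sparse pointer-advance logic that
# forward_inner drives via get_offset_for_next_block. Illustrative sketch only;
# the name and the list-based kv_indices stand in for the device-side tensors.
def _ref_offset_for_next_block(loop_iter, kv_indices, sparse_block, sparse_block_multiple, block):
    # Within a sparse block we step by BLOCK; once the last sub-block is done
    # we jump to the next sparse block recorded in kv_indices.
    cur_block_idx = loop_iter // sparse_block_multiple
    cur_block = kv_indices[cur_block_idx]
    next_block = kv_indices[min(cur_block_idx + 1, len(kv_indices) - 1)]
    needs_jump = (loop_iter + 1) % sparse_block_multiple == 0
    jump = (next_block - cur_block) * sparse_block - (sparse_block_multiple - 1) * block
    return jump if needs_jump else block
# e.g. kv_indices=[0, 3], sparse_block=128, multiple=2, block=64:
# iter 0 returns 64 (second half of sparse block 0), iter 1 returns 320,
# landing at offset 384 = 3 * 128, the start of sparse block 3.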


# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tf/ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py
# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul]
# Source node to ATen node mapping:
#   lse_scaled => mul
# Graph fragment:
#   %buf3 : Tensor = PlaceHolder[target=buf3]
#   %mul : Tensor "f32[1, 32, 976][31232, 976, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {})
#   return %mul
triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 32768},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 249856}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 31232
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = 0.6931471805599453
    tmp2 = tmp0 * tmp1
    tl.store(out_ptr0 + (x0), tmp2, xmask)
''', device_str='cuda')
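
# Note: 0.6931471805599453 is ln(2). The attention kernel accumulates its
# logsumexp in base 2 (via exp2 and RCP_LN2), so this pointwise multiply
# converts the stored LSE back to natural log for the [lse_scaled] output,
# since log2(x) * ln(2) == ln(x).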


async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1 = args
        args.clear()
        assert_size_stride(arg0_1, (1, 32, 976, 128), (3997696, 128, 4096, 1))
        assert_size_stride(arg1_1, (1, 8, 976, 128), (999424, 128, 1024, 1))
        assert_size_stride(arg2_1, (1, 8, 976, 128), (999424, 128, 1024, 1))
        assert_size_stride(arg3_1, (1, 1, 8), (8, 8, 1))
        assert_size_stride(arg4_1, (1, 1, 8, 8), (64, 64, 8, 1))
        assert_size_stride(arg5_1, (1, 1, 8), (8, 8, 1))
        assert_size_stride(arg6_1, (1, 1, 8, 8), (64, 64, 8, 1))
        assert_size_stride(arg7_1, (1, 1, 8), (8, 8, 1))
        assert_size_stride(arg8_1, (1, 1, 8, 8), (64, 64, 8, 1))
        assert_size_stride(arg9_1, (1, 1, 8), (8, 8, 1))
        assert_size_stride(arg10_1, (1, 1, 8, 8), (64, 64, 8, 1))
        with torch.cuda._DeviceGuard(7):
            torch.cuda.set_device(7)
            buf0 = empty_strided_cuda((1, 32, 976), (31232, 976, 1), torch.float32)
            buf1 = empty_strided_cuda((1, 32, 976), (31232, 976, 1), torch.float32)
            buf2 = empty_strided_cuda((1, 32, 976, 128), (3997696, 128, 4096, 1), torch.bfloat16)
            # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
            stream7 = get_raw_stream(7)
            triton_tem_fused_0.run(arg0_1, arg1_1, arg2_1, buf0, buf1, arg3_1, arg4_1, arg5_1, arg6_1, buf2, 8, 1, 32, stream=stream7)
            del arg0_1
            del arg1_1
            del arg2_1
            del arg3_1
            del arg4_1
            del arg5_1
            del arg6_1
            buf5 = buf1; del buf1  # reuse
            # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul]
            stream7 = get_raw_stream(7)
            triton_poi_fused_mul_1.run(buf0, buf5, 31232, stream=stream7)
            del buf0
        return (buf2, buf5, )

runner = Runner(partitions=[])
call = runner.call
recursively_apply_fns = runner.recursively_apply_fns

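# The benchmark below feeds call() rand_strided tensors that reproduce exactly
# the shapes/strides Runner.call asserts (e.g. Q as (1, 32, 976, 128) bf16 with
# strides (3997696, 128, 4096, 1)), so this wrapper can be timed standalone via
# compiled_module_main when the file is run directly.
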
def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    arg0_1 = rand_strided((1, 32, 976, 128), (3997696, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16)
    arg1_1 = rand_strided((1, 8, 976, 128), (999424, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16)
    arg2_1 = rand_strided((1, 8, 976, 128), (999424, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16)
    arg3_1 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:7', dtype=torch.int32)
    arg4_1 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:7', dtype=torch.int32)
    arg5_1 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:7', dtype=torch.int32)
    arg6_1 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:7', dtype=torch.int32)
    arg7_1 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:7', dtype=torch.int32)
    arg8_1 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:7', dtype=torch.int32)
    arg9_1 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:7', dtype=torch.int32)
    arg10_1 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:7', dtype=torch.int32)
    fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1])
    return print_performance(fn, times=times, repeat=repeat)


if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
progress/SpecForge/cache/compiled_kernels/3z/c3zi2pt6zmbthc6ythgt5p4ednhp6m24gpscb2pt6adf6xojetua.py
ADDED
@@ -0,0 +1,799 @@

import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties

@triton_heuristics.template(
    num_stages=3,
    num_warps=8,
    triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]},
    inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
)
@triton.jit
def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32
    Q = arg_Q
    K = arg_K
    V = arg_V
    LSE = arg_LSE
    DELTA = arg_DELTA
    DO = arg_DO
    DQ = arg_DQ
    DV = arg_DV
    KV_NUM_BLKS = arg_KV_NUM_BLKS
    KV_IDX = arg_KV_IDX
    Q_NUM_BLKS = arg_Q_NUM_BLKS
    Q_IDX = arg_Q_IDX
    FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
    FULL_KV_IDX = arg_FULL_KV_IDX
    FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
    FULL_Q_IDX = arg_FULL_Q_IDX

    # Sub notation for this kernel:
    #
    # Q: Query, K: Key, V: Value
    # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
    # DELTA: Precomputed sum(OUT*DO, axis=-1)
    # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
    # DK: Derivative of Key, which is written out via the store_output call due to some limitations with
    # inductor codegen
    # M: Number of queries, N: Number of keys/values
    # QK_HEAD_DIM: The dimension of the query and key embeddings
    # V_HEAD_DIM: The dimension of the value embeddings
    # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
    # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
    # (Modifiable) Performance tuning options
    # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
    # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
    # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
    # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
    #
    # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
    # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
    # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
    # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
    # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
    # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
    # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
    # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
    # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.

    # The below are kernel options that can be applied for certain score_mods,
    # or involve a numerics vs. perf tradeoff
    # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
    # about 20% more numerical error, but is slightly faster.

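    # Worked example for this instantiation (HQ = 32 query heads over HKV = 8
    # kv heads, defined just below): GQA_SHARED_HEADS = HQ // HKV = 4, so each
    # kv-head program accumulates dK/dV over its 4 query heads.
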
    # Define strides of inputs
    stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1
    stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1
    stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1
    stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1

    stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1
    stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1

    ZQ = 1
    HQ = 32
    HKV = 8
    Q_LEN = ks0
    ZKV = 1
    KV_LEN = ks1

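    # Stride sanity check (illustrative, assuming the layout the forward
    # wrapper asserts): Q of shape (1, 32, ks0, 128) has head stride 128
    # (= head dim), row stride 4096 (= 32 heads * 128), and batch stride
    # 4096*ks0; for ks0 = 976 that is the 3997696 seen in assert_size_stride.
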
    MATMUL_PRECISION = Q.dtype.element_ty

    pid = tl.program_id(0).to(INDEX_DTYPE)
    NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
    NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)

    off_zq = tl.program_id(1).to(INDEX_DTYPE)  # q batch idx
    off_hkv = tl.program_id(2).to(INDEX_DTYPE)  # kv head idx
    off_zkv = off_zq % ZKV  # kv batch idx

    SPARSE_Z = 1
    SPARSE_HQ = 1

    sparse_idx_z = off_zq % SPARSE_Z

    k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
    v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
    # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
    # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
    dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)

    # offset K, V, DV pointers for batch/kv-head
    K += k_adj
    V += v_adj
    DV += dv_adj

    RCP_LN2 = 1.44269504
    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)

    if pid >= NUM_KV_BLOCKS:
        off_pid = pid - NUM_KV_BLOCKS
        # THIS BLOCK DOES DQ
        SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
        SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
        off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
        start_m2_block = off_pid % NUM_Q_BLOCKS
        off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
        stride_kv_num_blks_h = 1
        stride_kv_idx_h = 1
        stride_kv_idx_m = 1

        sparse_idx_hq2 = off_hq2 % SPARSE_HQ
        sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2

        sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
        sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m  # noqa: B950

        # Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query heads.
        q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
        do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
        dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
        off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)

        Q2 = Q + q_adj2
        DO2 = DO + do_adj2
        # TODO: This does not work if DQ is not the same layout as Q (for example,
        # if Q is broadcasted)
        DQ2 = DQ + dq_adj2
        LSE2 = LSE + off_chz2
        DELTA2 = DELTA + off_chz2

        # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
        dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)

        start_m2 = start_m2_block * BLOCK_M2
        offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)

        # load Q and do: they stay in SRAM throughout the inner loop.
        q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
        do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)

        if PRESCALE_QK:
            q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)

        if IS_DIVISIBLE:
            Di = tl.load(DELTA2 + offs_m2)
            lse = tl.load(LSE2 + offs_m2)
        else:
            Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
            lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
        lse = tl.where(lse == -float("inf"), 0.0, lse)
        lse = lse[:, None]

        # ~~~~~~~~~~~ partially unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # KV_IDX and KV_NUM_BLKS are always contiguous.
        kv_indices = KV_IDX + sparse_kv_idx_offset
        kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE  # first kv block we're loading
        sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)

        offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
        dq = bwd_dq_inner(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
            K, V,
            dq, q, do, Di, lse,
            off_zq, off_hq2, offs_m2, offs_n2,
            stride_kn, stride_kd, stride_vn, stride_vd,
            kv_indices, sparse_kv_num_blocks,
            MATMUL_PRECISION,
            IS_FULL_BLOCKS=False,
        )

        if HAS_FULL_BLOCKS:
            # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
            kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
            kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE  # first kv block we're loading
            sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)

            offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
            dq = bwd_dq_inner(
                arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
                K, V,
                dq, q, do, Di, lse,
                off_zq, off_hq2, offs_m2, offs_n2,
                stride_kn, stride_kd, stride_vn, stride_vd,
                kv_indices, sparse_kv_num_blocks,
                MATMUL_PRECISION,
                IS_FULL_BLOCKS=True,
            )

        # Write back dQ.
        dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
        dq *= SM_SCALE
        if IS_DIVISIBLE and SAFE_HEAD_DIM:
            tl.store(dq_ptrs, dq)
        else:
            tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
    else:
        # THIS BLOCK DOES DK & DV
        SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
        SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)

        pid_mask = pid // SPARSE_KV_MULTIPLE

        stride_q_num_blks_h = 1
        stride_q_idx_h = 1
        stride_q_idx_n = 1

        dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
        dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)

        start_n1 = pid * BLOCK_N1
        offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)

        # load K and V: they stay in SRAM throughout the inner loop.
        k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
        v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)

        if PRESCALE_QK:
            k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)

        for off_g in range(0, GQA_SHARED_HEADS):
            off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g

            # Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query heads.
            q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
            do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
            dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
            off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)

            Q1 = Q + q_adj1
            DO1 = DO + do_adj1
            # TODO: This does not work if DQ is not the same layout as Q (for example,
            # if Q is broadcasted)
            LSE1 = LSE + off_chz1
            DELTA1 = DELTA + off_chz1

            sparse_idx_hq1 = off_hq1 % SPARSE_HQ
            sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1

            sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
            sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n  # noqa: B950

            # ~~~~~~~~~~~~~~~ partially unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            # Q_IDX and Q_NUM_BLKS are always contiguous.
            q_indices = Q_IDX + sparse_q_idx_offset
            q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE  # first q block we're loading
            sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)

            offs_m1 = q_start + tl.arange(0, BLOCK_M1)
            dk, dv = bwd_dkdv_inner(
                arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
                Q1, DO1, DELTA1, LSE1,
                dk, dv, k, v,
                off_zq, off_hq1, offs_n1, offs_m1,
                stride_qm, stride_qd, stride_dom, stride_dod,
                q_indices, sparse_q_num_blocks,
                MATMUL_PRECISION,
                IS_FULL_BLOCKS=False,
            )

            if HAS_FULL_BLOCKS:
                # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
                # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
                q_indices = FULL_Q_IDX + sparse_q_idx_offset
                q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE  # first q block we're loading
                sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)

                offs_m1 = q_start + tl.arange(0, BLOCK_M1)
                dk, dv = bwd_dkdv_inner(
                    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
                    Q1, DO1, DELTA1, LSE1,
                    dk, dv, k, v,
                    off_zq, off_hq1, offs_n1, offs_m1,
                    stride_qm, stride_qd, stride_dom, stride_dod,
                    q_indices, sparse_q_num_blocks,
                    MATMUL_PRECISION,
                    IS_FULL_BLOCKS=True,
                )

        # Write back dV and dK.
        dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd

        index_n = offs_n1[:, None]
        index_k = offs_k[None, :]
        index_v = offs_v[None, :]

        if IS_DIVISIBLE and SAFE_HEAD_DIM:
            tl.store(dv_ptrs, dv)
        else:
            tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))

        dk *= SM_SCALE

        if SAFE_HEAD_DIM:
            mask = index_n < KV_LEN
        else:
            mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)

        # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
        # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
        tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
        xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1
        tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask)

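# Grid layout note: programs with pid < NUM_KV_BLOCKS handle one BLOCK_N1 slice
# of K/V and compute dK/dV; programs with pid >= NUM_KV_BLOCKS are re-indexed
# into (query head, q-block) pairs and compute dQ for one BLOCK_M2 slice of Q.
# The two families write disjoint buffers (DQ vs. DV/out_ptr0), so no atomics
# are needed between them.
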
@triton.jit
def bwd_dq_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
    K, V,  # pointers
    dq, q, do, Di, lse,
    off_z, off_hq, offs_m2, offs_n2,
    stride_kn, stride_kd, stride_vn, stride_vd,
    kv_indices, sparse_kv_num_blocks,
    MATMUL_PRECISION,
    IS_FULL_BLOCKS,
):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
    RCP_LN2: tl.constexpr = 1.44269504
    Q_LEN = ks0
    KV_LEN = ks1

    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)

    kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
    vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
    # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)

    hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))

    for start_n in range(0, hi):
        dq = bwd_dq_block_mn(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
            dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
            off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
            stride_kn, stride_kd, stride_vn, stride_vd,
            kv_indices, sparse_kv_num_blocks,
            MATMUL_PRECISION, RCP_LN2,
            IS_FULL_BLOCKS,
        )

        # Increment pointers.
        offset = get_offset_for_next_block(
            start_n, kv_indices, sparse_kv_num_blocks,
            SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
        )

        kT_ptrs += offset * stride_kn
        vT_ptrs += offset * stride_vn

        offs_n2 += offset

    return dq


@triton.jit
def bwd_dq_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
    dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
    off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
    stride_kn, stride_kd, stride_vn, stride_vd,
    kv_indices, sparse_kv_num_blocks,
    MATMUL_PRECISION, RCP_LN2,
    IS_FULL_BLOCKS,
):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    # NB reversed order since K is transposed
    kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
    qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
    if not PRESCALE_QK:
        qk *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    pre_mod_scores = qk
    n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
    # The boundary check is done for the outer loop, but here it's possible, since we're
    # iterating across the N dim, that the M reads out of bounds for the PIDs spanning
    # the Q_LEN boundary
    m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)

    tmp0 = (qk)
    post_mod_scores = tmp0

    if not IS_DIVISIBLE:
        post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        tmp1 = (m)
        tmp2 = tl.full([1], 0, tl.int32)
        tmp3 = tmp1 < tmp2
        tmp4 = (n)
        tmp5 = tmp4 <= tmp1
        tmp6 = tmp3 & tmp5
        tmp7 = tmp1 >= tmp2
        tmp8 = tmp4 < tmp2
        tmp9 = tmp7 & tmp8
        tmp10 = tmp8 == 0
        tmp11 = tmp7 & tmp10
        tmp12 = tmp1 - tmp2
        tmp13 = tl.full([1], 16, tl.int32)
        tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
        tmp15 = tmp4 - tmp2
        tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
        tmp17 = tmp14 == tmp16
        tmp18 = tmp11 & tmp17
        tmp19 = tmp9 | tmp18
        tmp20 = tmp6 | tmp19
        mask_mod_output = tmp20

        # apply mask for partially masked block
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if not PRESCALE_QK:
        post_mod_scores *= RCP_LN2
    p = tl.math.exp2(post_mod_scores - lse)
    # Compute dP and dS.
    # NB reversed order since V is transposed
    vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)

    dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
    ds = p * (dp - Di[:, None])
    # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
    tmp21 = (ds)
    grad_scores = tmp21

    if not IS_DIVISIBLE:
        grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)

    # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
    if WRITE_DQ:
        scatter_mask = (offs_m2[:, None] < Q_LEN) & (offs_n2[None, :] < KV_LEN)

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    ds = grad_scores

    if not IS_FULL_BLOCKS:
        # (grads) apply mask for partially unmasked block
        ds = tl.where(mask_mod_output, ds, 0.0)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    ds = ds.to(MATMUL_PRECISION)
    # Compute dQ.
    dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)

    return dq


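# Note on the dS formula used in bwd_dq_block_mn above: with P = softmax(S) and
# Di = rowsum(O * dO) (precomputed as DELTA), softmax backprop gives
# dS = P * (dP - Di), which is exactly ds = p * (dp - Di[:, None]).
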
@triton.jit
|
| 541 |
+
def bwd_dkdv_inner(
|
| 542 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 543 |
+
Q, DO, DELTA, LSE, # pointers
|
| 544 |
+
dk, dv, k, v,
|
| 545 |
+
off_z, off_hq, offs_n1, offs_m1,
|
| 546 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 547 |
+
q_indices, sparse_q_num_blocks,
|
| 548 |
+
MATMUL_PRECISION,
|
| 549 |
+
IS_FULL_BLOCKS,
|
| 550 |
+
):
|
| 551 |
+
PRESCALE_QK : tl.constexpr = False
|
| 552 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 553 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 554 |
+
WRITE_DQ : tl.constexpr = True
|
| 555 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 556 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 557 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 558 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 559 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 560 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 561 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 562 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 563 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 564 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 565 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 566 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 567 |
+
BLOCK_M1 : tl.constexpr = 64
|
| 568 |
+
BLOCK_N1 : tl.constexpr = 128
|
| 569 |
+
BLOCK_M2 : tl.constexpr = 128
|
| 570 |
+
BLOCK_N2 : tl.constexpr = 64
|
| 571 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 572 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 573 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 574 |
+
|
| 575 |
+
SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
|
| 576 |
+
RCP_LN2: tl.constexpr = 1.44269504
|
| 577 |
+
Q_LEN = ks0
|
| 578 |
+
KV_LEN = ks1
|
| 579 |
+
|
| 580 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 581 |
+
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
| 582 |
+
|
| 583 |
+
qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
|
| 584 |
+
do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
|
| 585 |
+
# BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
|
| 586 |
+
tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
|
| 587 |
+
|
| 588 |
+
# The minimum is needed to handle the case where we run with a super large
|
| 589 |
+
# SPARSE_BLOCK_SIZE (i.e. no block-mask!)
|
| 590 |
+
hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
|
| 591 |
+
|
| 592 |
+
for start_m in range(0, hi):
|
| 593 |
+
dk, dv = bwd_dkdv_block_mn(
|
| 594 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 595 |
+
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
|
| 596 |
+
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
|
| 597 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 598 |
+
q_indices, sparse_q_num_blocks,
|
| 599 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 600 |
+
IS_FULL_BLOCKS,
|
| 601 |
+
)
|
| 602 |
+
# Increment pointers.
|
| 603 |
+
offset = get_offset_for_next_block(
|
| 604 |
+
start_m, q_indices, sparse_q_num_blocks,
|
| 605 |
+
SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
|
| 606 |
+
)
|
| 607 |
+
|
| 608 |
+
qT_ptrs += offset * stride_qm
|
| 609 |
+
do_ptrs += offset * stride_dom
|
| 610 |
+
offs_m1 += offset
|
| 611 |
+
|
| 612 |
+
return dk, dv
|
| 613 |
+
|
| 614 |
+
|
| 615 |
+
@triton.jit
|
| 616 |
+
def bwd_dkdv_block_mn(
|
| 617 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 618 |
+
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
|
| 619 |
+
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
|
| 620 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 621 |
+
q_indices, sparse_q_num_blocks,
|
| 622 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 623 |
+
IS_FULL_BLOCKS,
|
| 624 |
+
):
|
| 625 |
+
PRESCALE_QK : tl.constexpr = False
|
| 626 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 627 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 628 |
+
WRITE_DQ : tl.constexpr = True
|
| 629 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 630 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 631 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 632 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 633 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 634 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 635 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 636 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 637 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 638 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 639 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 640 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 641 |
+
BLOCK_M1 : tl.constexpr = 64
|
| 642 |
+
BLOCK_N1 : tl.constexpr = 128
|
| 643 |
+
BLOCK_M2 : tl.constexpr = 128
|
| 644 |
+
BLOCK_N2 : tl.constexpr = 64
|
| 645 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 646 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 647 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 648 |
+
|
| 649 |
+
|
| 650 |
+
# NB reversed order since Q is transposed
|
| 651 |
+
qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
|
| 652 |
+
# Load LSE before computing qk to reduce pipeline stall.
|
| 653 |
+
if IS_DIVISIBLE:
|
| 654 |
+
lse = tl.load(LSE + offs_m1)
|
| 655 |
+
else:
|
| 656 |
+
lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
|
| 657 |
+
lse = tl.where(lse == -float("inf"), 0.0, lse)
|
| 658 |
+
qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
|
| 659 |
+
if not PRESCALE_QK:
|
| 660 |
+
qkT *= SM_SCALE
|
| 661 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
|
| 662 |
+
m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
|
| 663 |
+
# The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim
|
| 664 |
+
# that the n reads out of bounds for the PIDS spanning the KV_LEN boundary
|
| 665 |
+
n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
|
| 666 |
+
|
| 667 |
+
pre_mod_scores = qkT
|
| 668 |
+
tmp22 = (qkT)
|
| 669 |
+
post_mod_scores = tmp22
|
| 670 |
+
|
| 671 |
+
|
| 672 |
+
|
| 673 |
+
if not IS_DIVISIBLE:
|
| 674 |
+
post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
|
| 675 |
+
|
| 676 |
+
if not IS_FULL_BLOCKS:
|
| 677 |
+
tmp23 = (m)
|
| 678 |
+
tmp24 = tl.full([1], 0, tl.int32)
|
| 679 |
+
tmp25 = tmp23 < tmp24
|
| 680 |
+
tmp26 = (n)
|
| 681 |
+
tmp27 = tmp26 <= tmp23
|
| 682 |
+
tmp28 = tmp25 & tmp27
|
| 683 |
+
tmp29 = tmp23 >= tmp24
|
| 684 |
+
tmp30 = tmp26 < tmp24
|
| 685 |
+
tmp31 = tmp29 & tmp30
|
| 686 |
+
tmp32 = tmp30 == 0
|
| 687 |
+
tmp33 = tmp29 & tmp32
|
| 688 |
+
tmp34 = tmp23 - tmp24
|
| 689 |
+
tmp35 = tl.full([1], 16, tl.int32)
|
| 690 |
+
tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35)
|
| 691 |
+
tmp37 = tmp26 - tmp24
|
| 692 |
+
tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35)
|
| 693 |
+
tmp39 = tmp36 == tmp38
|
| 694 |
+
tmp40 = tmp33 & tmp39
|
| 695 |
+
tmp41 = tmp31 | tmp40
|
| 696 |
+
tmp42 = tmp28 | tmp41
|
| 697 |
+
mask_mod_output = tmp42
|
| 698 |
+
|
| 699 |
+
# (grads) apply mask for fully masked block
|
| 700 |
+
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
|
| 701 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 702 |
+
if not PRESCALE_QK:
|
| 703 |
+
post_mod_scores *= RCP_LN2
|
| 704 |
+
pT = tl.math.exp2(post_mod_scores - lse[None, :])
|
| 705 |
+
do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
|
| 706 |
+
# Compute dV.
|
| 707 |
+
ppT = pT
|
| 708 |
+
dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
|
| 709 |
+
if IS_DIVISIBLE:
|
| 710 |
+
Di = tl.load(DELTA + offs_m1)
|
| 711 |
+
else:
|
| 712 |
+
Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
|
| 713 |
+
# Compute dP and dS.
|
| 714 |
+
dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
|
| 715 |
+
dsT = pT * (dpT - Di[None, :])
|
| 716 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
|
| 717 |
+
tmp43 = (dsT)
|
| 718 |
+
grad_scores = tmp43
|
| 719 |
+
|
| 720 |
+
|
| 721 |
+
|
| 722 |
+
if not IS_DIVISIBLE:
|
| 723 |
+
grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
|
| 724 |
+
|
| 725 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
|
| 726 |
+
if not WRITE_DQ:
|
| 727 |
+
idx_b = off_z
|
| 728 |
+
idx_h = off_hq
|
| 729 |
+
idx_m = m
|
| 730 |
+
idx_n = n
|
| 731 |
+
scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
|
| 732 |
+
|
| 733 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 734 |
+
dsT = grad_scores
|
| 735 |
+
if not IS_FULL_BLOCKS:
|
| 736 |
+
# (grads) apply mask for partially unmasked block
|
| 737 |
+
dsT = tl.where(mask_mod_output, dsT, 0.0)
|
| 738 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 739 |
+
dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
|
| 740 |
+
|
| 741 |
+
return dk, dv
|
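# For reference, the "dsT = pT * (dpT - Di[None, :])" step above is the standard softmax
# backward identity: with P = softmax(scale * Q K^T), dS = P * (dP - Delta) where
# Delta_m = sum_d O[m, d] * dO[m, d] (the precomputed DELTA/Di buffer). Below is a minimal
# dense NumPy sketch of the dK/dV math, illustrative only: no masking, no block sparsity,
# natural-base exp instead of the kernel's exp2/RCP_LN2 change of base, and the scale is
# applied directly here while the real kernel folds it in elsewhere.
import numpy as np

def _attention_bwd_kv_reference(q, k, v, do, scale):
    """Dense reference for the dK/dV computation sketched by the kernel above."""
    s = (q @ k.T) * scale                        # scores [M, N]
    p = np.exp(s - s.max(-1, keepdims=True))
    p /= p.sum(-1, keepdims=True)                # softmax probabilities P
    o = p @ v                                    # forward output [M, Dv]
    dv = p.T @ do                                # dV = P^T dO
    dp = do @ v.T                                # dP = dO V^T
    delta = (o * do).sum(-1, keepdims=True)      # Delta_m = sum_d O * dO  (the "Di" above)
    ds = p * (dp - delta)                        # dS = P * (dP - Delta)
    dk = scale * (ds.T @ q)                      # dK = scale * dS^T Q
    return dk, dv

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    M, N, D = 4, 6, 8
    dk, dv = _attention_bwd_kv_reference(
        rng.normal(size=(M, D)), rng.normal(size=(N, D)),
        rng.normal(size=(N, D)), rng.normal(size=(M, D)), scale=D ** -0.5)
    assert dk.shape == (N, D) and dv.shape == (N, D)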
| 742 |
+
|
| 743 |
+
# Utility triton funcs
|
| 744 |
+
@triton.jit
|
| 745 |
+
def get_offset_for_next_block(
|
| 746 |
+
loop_iter, col_indices, total_blocks,
|
| 747 |
+
SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
|
| 748 |
+
BLOCKS_ARE_CONTIGUOUS: tl.constexpr
|
| 749 |
+
):
|
| 750 |
+
if BLOCKS_ARE_CONTIGUOUS:
|
| 751 |
+
return BLOCK
|
| 752 |
+
cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
|
| 753 |
+
cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
|
| 754 |
+
next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
|
| 755 |
+
needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
|
| 756 |
+
jump_to_block = (next_block - cur_block) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
|
| 757 |
+
offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
|
| 758 |
+
return offset
|
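# A plain-Python walkthrough of the jump arithmetic above (hypothetical, made-up numbers:
# SPARSE_BLOCK=128 and BLOCK=64, so SPARSE_BLOCK_MULTIPLE=2). Inside a sparse block the
# offset is just BLOCK; on the last BLOCK-tile of a sparse block it jumps to the start of
# the next selected sparse block.
#
#   SPARSE_BLOCK, BLOCK = 128, 64
#   MULTIPLE = SPARSE_BLOCK // BLOCK             # 2 BLOCK-tiles per sparse block
#   col_indices = [0, 3, 5]                      # selected sparse blocks (made up)
#   pos = col_indices[0] * SPARSE_BLOCK
#   for it in range(len(col_indices) * MULTIPLE - 1):
#       cur = col_indices[it // MULTIPLE]
#       nxt = col_indices[min(it // MULTIPLE + 1, len(col_indices) - 1)]
#       needs_jump = (it + 1) % MULTIPLE == 0
#       jump = (nxt - cur) * SPARSE_BLOCK - (MULTIPLE - 1) * BLOCK
#       pos += jump if needs_jump else BLOCK
#   assert pos == col_indices[-1] * SPARSE_BLOCK + (MULTIPLE - 1) * BLOCK   # pos == 704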
| 759 |
+
|
| 760 |
+
@triton.jit
|
| 761 |
+
def get_bounded_indices(indices, max_len=None):
|
| 762 |
+
return indices % max_len if max_len is not None else indices
|
| 763 |
+
|
| 764 |
+
@triton.jit
|
| 765 |
+
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
|
| 766 |
+
if IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 767 |
+
return tl.load(block_ptr)
|
| 768 |
+
elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
|
| 769 |
+
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
|
| 770 |
+
elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 771 |
+
return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
|
| 772 |
+
else:
|
| 773 |
+
return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
|
| 774 |
+
|
| 775 |
+
@triton.jit
|
| 776 |
+
def load_checked_2d(
|
| 777 |
+
ptr,
|
| 778 |
+
offs_m,
|
| 779 |
+
offs_n,
|
| 780 |
+
stride_m,
|
| 781 |
+
stride_n,
|
| 782 |
+
IS_DIVISIBLE_M: tl.constexpr,
|
| 783 |
+
IS_DIVISIBLE_N: tl.constexpr,
|
| 784 |
+
M_LEN: tl.constexpr,
|
| 785 |
+
N_LEN: tl.constexpr,
|
| 786 |
+
):
|
| 787 |
+
# Calculate final pointer if strides are provided
|
| 788 |
+
if stride_m is not None and stride_n is not None:
|
| 789 |
+
ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
|
| 790 |
+
|
| 791 |
+
# Handle all masking cases
|
| 792 |
+
if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
|
| 793 |
+
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
|
| 794 |
+
elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
|
| 795 |
+
return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
|
| 796 |
+
elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
|
| 797 |
+
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
|
| 798 |
+
else: # Both divisible
|
| 799 |
+
return tl.load(ptr)
|
progress/SpecForge/cache/compiled_kernels/3z/c3zilfzjywngbdehwphwkhzpt6qcv6jecvzdajl2d5hb73xe6yzw.py
ADDED
|
@@ -0,0 +1,582 @@
|
| 1 |
+
|
| 2 |
+
import triton
|
| 3 |
+
import triton.language as tl
|
| 4 |
+
|
| 5 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 6 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 7 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 8 |
+
|
| 9 |
+
@triton_heuristics.template(
|
| 10 |
+
|
| 11 |
+
num_stages=3,
|
| 12 |
+
num_warps=2,
|
| 13 |
+
triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]},
|
| 14 |
+
inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}},
|
| 15 |
+
|
| 16 |
+
)
|
| 17 |
+
@triton.jit
|
| 18 |
+
def triton_flex_decoding(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1):
|
| 19 |
+
PRESCALE_QK : tl.constexpr = False
|
| 20 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 21 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 22 |
+
WRITE_DQ : tl.constexpr = True
|
| 23 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 24 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 25 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 26 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 27 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 28 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 29 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 30 |
+
SPLIT_KV : tl.constexpr = 32
|
| 31 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 32 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 33 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 34 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 35 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 36 |
+
BLOCK_M : tl.constexpr = 512
|
| 37 |
+
SAFE_M_BOUNDARY : tl.constexpr = False
|
| 38 |
+
SAFE_N_BOUNDARY : tl.constexpr = True
|
| 39 |
+
BLOCK_N : tl.constexpr = 64
|
| 40 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 41 |
+
USE_TMA : tl.constexpr = False
|
| 42 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 43 |
+
Q = arg_Q
|
| 44 |
+
K = arg_K
|
| 45 |
+
V = arg_V
|
| 46 |
+
M = arg_M
|
| 47 |
+
L = arg_L
|
| 48 |
+
KV_NUM_BLKS = arg_KV_NUM_BLKS
|
| 49 |
+
KV_IDX = arg_KV_IDX
|
| 50 |
+
FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
|
| 51 |
+
FULL_KV_IDX = arg_FULL_KV_IDX
|
| 52 |
+
|
| 53 |
+
# Sub notation for this kernel:
|
| 54 |
+
# Q: Query, K: Key, V: Value
|
| 55 |
+
# reduction buffers: M rowmax across local KV split, L local sumexp across local KV split
|
| 56 |
+
# M: Number of queries, N: Number of keys/values
|
| 57 |
+
# QK_HEAD_DIM: The dimension of the query and key embeddings
|
| 58 |
+
# V_HEAD_DIM: The dimension of the value embeddings
|
| 59 |
+
# BLOCK_M, QK_HEAD_DIM: the M and D dimensions are always assigned to the same block
|
| 60 |
+
# z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head, t: Number of kv splits
|
| 61 |
+
# (Modifiable) Config options:
|
| 62 |
+
# SPLIT_KV: number of blocks K & V are split into
|
| 63 |
+
# TILE_KV: length of each local KV split
|
| 64 |
+
# BLOCK_M: block size that Q is padded to along the seqlen dim.
|
| 65 |
+
# BLOCK_N: block size of K & V along N dimension.
|
| 66 |
+
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
|
| 67 |
+
#
|
| 68 |
+
#
|
| 69 |
+
# ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
|
| 70 |
+
# is not masked out? If so, we can skip an extra safety check
|
| 71 |
+
# SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query.
|
| 72 |
+
# SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value.
|
| 73 |
+
|
| 74 |
+
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and do the change of base out of the loop.
|
| 75 |
+
#
|
| 76 |
+
# SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim.
|
| 77 |
+
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
|
| 78 |
+
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
|
| 79 |
+
#
|
| 80 |
+
#
|
| 81 |
+
# Output: ACC output accumulated across local KV split.
|
| 82 |
+
|
| 83 |
+
tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
|
| 84 |
+
|
| 85 |
+
# Define Q Strides
|
| 86 |
+
stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1
|
| 87 |
+
stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1
|
| 88 |
+
stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1
|
| 89 |
+
stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1
|
| 90 |
+
stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
Z = 1
|
| 94 |
+
ZKV = 1
|
| 95 |
+
HKV = 8
|
| 96 |
+
G: tl.constexpr = GQA_SHARED_HEADS
|
| 97 |
+
HQ = HKV * G
|
| 98 |
+
Q_LEN = ks0
|
| 99 |
+
KV_LEN = ks1
|
| 100 |
+
|
| 101 |
+
MATMUL_PRECISION = Q.dtype.element_ty
|
| 102 |
+
|
| 103 |
+
# Make sure each split is a multiple of BLOCK_N
|
| 104 |
+
TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV)
|
| 105 |
+
TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N
|
| 106 |
+
TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N)
|
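# Worked example of the split arithmetic above (hypothetical KV_LEN, with the configured
# SPLIT_KV=32 and BLOCK_N=64):
#   KV_LEN = 4000
#   TILE_KV_OG       = cdiv(4000, 32)      = 125
#   TILE_KV          = cdiv(125, 64) * 64  = 128   # rounded up to a BLOCK_N multiple
#   TILE_KV_MULTIPLE = 128 // 64           = 2     # BLOCK_N tiles handled per KV split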
| 107 |
+
|
| 108 |
+
off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV
|
| 109 |
+
off_zkv = off_z % ZKV
|
| 110 |
+
off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV
|
| 111 |
+
off_t = tl.program_id(1).to(INDEX_DTYPE)
|
| 112 |
+
|
| 113 |
+
q_offset = off_z * stride_qz + off_hkv * stride_qh
|
| 114 |
+
k_offset = off_zkv * stride_kz + off_hkv * stride_kh
|
| 115 |
+
v_offset = off_zkv * stride_vz + off_hkv * stride_vh
|
| 116 |
+
|
| 117 |
+
K = K + k_offset
|
| 118 |
+
V = V + v_offset
|
| 119 |
+
|
| 120 |
+
SPARSE_Z = 1
|
| 121 |
+
SPARSE_HQ = 1
|
| 122 |
+
|
| 123 |
+
sparse_idx_z = off_z % SPARSE_Z
|
| 124 |
+
sparse_idx_h = off_hkv % SPARSE_HQ
|
| 125 |
+
|
| 126 |
+
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
|
| 127 |
+
SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE)
|
| 128 |
+
|
| 129 |
+
# initialize pointer to m and l
|
| 130 |
+
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
| 131 |
+
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
| 132 |
+
acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
| 133 |
+
|
| 134 |
+
# initialize offsets
|
| 135 |
+
tl.device_assert(BLOCK_M % G == 0)
|
| 136 |
+
BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G
|
| 137 |
+
off_g = tl.arange(0, G) # [G]
|
| 138 |
+
offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M]
|
| 139 |
+
offs_hq = offs_g + off_hkv * G
|
| 140 |
+
off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ]
|
| 141 |
+
offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M]
|
| 142 |
+
offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 143 |
+
offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
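# To make the GQA packing above concrete: with G=4 and BLOCK_M=512, BLOCK_M_PER_HQ=128, so
# one CTA tile stacks 128 query rows for each of the 4 query heads sharing this KV head.
# A small NumPy sketch of the same index layout (illustrative only):
#
#   import numpy as np
#   G, BLOCK_M = 4, 512
#   BLOCK_M_PER_HQ = BLOCK_M // G
#   offs_g = np.broadcast_to(np.arange(G)[:, None], (G, BLOCK_M_PER_HQ)).ravel()
#   offs_m = np.broadcast_to(np.arange(BLOCK_M_PER_HQ)[None, :], (G, BLOCK_M_PER_HQ)).ravel()
#   # row i of the packed [BLOCK_M] tile -> (query head offs_g[i], query row offs_m[i])
#   assert offs_g[127] == 0 and offs_g[128] == 1      # heads change every 128 rows
#   assert offs_m[127] == 127 and offs_m[128] == 0    # rows restart for each head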
| 144 |
+
|
| 145 |
+
# Get HZ offsets for KV_NUM_BLKS and KV_IDX
|
| 146 |
+
stride_block_z, stride_block_h, stride_block_row = 1, 1, 1
|
| 147 |
+
sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h
|
| 148 |
+
stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1
|
| 149 |
+
sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h
|
| 150 |
+
|
| 151 |
+
# Calculate KV blocks that belong to this CTA.
|
| 152 |
+
block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block
|
| 153 |
+
block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N
|
| 154 |
+
|
| 155 |
+
q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :]
|
| 156 |
+
|
| 157 |
+
if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM:
|
| 158 |
+
q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN))
|
| 159 |
+
elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM:
|
| 160 |
+
q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM)
|
| 161 |
+
elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM:
|
| 162 |
+
q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN)
|
| 163 |
+
else:
|
| 164 |
+
q = tl.load(Q + q_offset + q_range)
|
| 165 |
+
|
| 166 |
+
q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED])
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
# ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 170 |
+
# find first kv block we are loading and the number of blocks we are loading
|
| 171 |
+
# Offset the kv_indices tensor by the correct batch and head
|
| 172 |
+
kv_indices = KV_IDX + sparse_idx_hz_offset
|
| 173 |
+
kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset)
|
| 174 |
+
MAX_KV_IDX = 1
|
| 175 |
+
indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX)
|
| 176 |
+
off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
|
| 177 |
+
off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
|
| 178 |
+
# off_n is the first kv block we're loading
|
| 179 |
+
|
| 180 |
+
# last valid block according to sparse mask
|
| 181 |
+
block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
|
| 182 |
+
|
| 183 |
+
offs_n = tl.arange(0, BLOCK_N) + off_n
|
| 184 |
+
|
| 185 |
+
desc_k = None
|
| 186 |
+
desc_v = None
|
| 187 |
+
|
| 188 |
+
acc, l_i, m_i = forward_inner(
|
| 189 |
+
arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
|
| 190 |
+
q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
|
| 191 |
+
# accumulated values
|
| 192 |
+
acc, l_i, m_i,
|
| 193 |
+
# offsets
|
| 194 |
+
off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
|
| 195 |
+
off_n,
|
| 196 |
+
# block sparse data
|
| 197 |
+
kv_indices, kv_num_blocks,
|
| 198 |
+
block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
|
| 199 |
+
MATMUL_PRECISION,
|
| 200 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 201 |
+
IS_FULL_BLOCKS=False,
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
# ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 206 |
+
# We know these blocks are guaranteed to be "full", so we don't need to
|
| 207 |
+
# apply mask_mod to them - only score_mod
|
| 208 |
+
if HAS_FULL_BLOCKS:
|
| 209 |
+
kv_indices = FULL_KV_IDX + sparse_idx_hz_offset
|
| 210 |
+
kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset)
|
| 211 |
+
# Assign full blocks in reverse order of off_t. Prioritize the last CTA.
|
| 212 |
+
block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE
|
| 213 |
+
block_n_end = block_n_start + TILE_KV_MULTIPLE
|
| 214 |
+
indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX)
|
| 215 |
+
off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
|
| 216 |
+
off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
|
| 217 |
+
|
| 218 |
+
# last valid block according to sparse mask
|
| 219 |
+
block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
|
| 220 |
+
|
| 221 |
+
offs_n = tl.arange(0, BLOCK_N) + off_n
|
| 222 |
+
|
| 223 |
+
acc, l_i, m_i = forward_inner(
|
| 224 |
+
arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
|
| 225 |
+
q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
|
| 226 |
+
# accumulated values
|
| 227 |
+
acc, l_i, m_i,
|
| 228 |
+
# offsets
|
| 229 |
+
off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
|
| 230 |
+
off_n,
|
| 231 |
+
# block sparse data
|
| 232 |
+
kv_indices, kv_num_blocks,
|
| 233 |
+
block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
|
| 234 |
+
MATMUL_PRECISION,
|
| 235 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 236 |
+
IS_FULL_BLOCKS=True,
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
m_offset = off_t * stride_mt + off_z * stride_mz
|
| 240 |
+
l_offset = off_t * stride_lt + off_z * stride_lz
|
| 241 |
+
|
| 242 |
+
M_block_ptr = tl.make_block_ptr(
|
| 243 |
+
base=M + m_offset,
|
| 244 |
+
shape=(G, Q_LEN), # (G, M)
|
| 245 |
+
strides=(stride_mh, stride_mm),
|
| 246 |
+
offsets=(off_hkv*G, 0),
|
| 247 |
+
block_shape=(G, BLOCK_M_PER_HQ),
|
| 248 |
+
order=(1, 0)
|
| 249 |
+
)
|
| 250 |
+
L_block_ptr = tl.make_block_ptr(
|
| 251 |
+
base=L + l_offset,
|
| 252 |
+
shape=(G, Q_LEN), # (G, M)
|
| 253 |
+
strides=(stride_lh, stride_lm),
|
| 254 |
+
offsets=(off_hkv*G, 0),
|
| 255 |
+
block_shape=(G, BLOCK_M_PER_HQ),
|
| 256 |
+
order=(1, 0)
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
# Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16)
|
| 260 |
+
m_i = m_i.reshape(G, BLOCK_M_PER_HQ)
|
| 261 |
+
l_i = l_i.reshape(G, BLOCK_M_PER_HQ)
|
| 262 |
+
if SAFE_M_BOUNDARY:
|
| 263 |
+
tl.store(M_block_ptr, m_i)
|
| 264 |
+
tl.store(L_block_ptr, l_i)
|
| 265 |
+
else:
|
| 266 |
+
tl.store(M_block_ptr, m_i, boundary_check=(1,))
|
| 267 |
+
tl.store(L_block_ptr, l_i, boundary_check=(1,))
|
| 268 |
+
|
| 269 |
+
# -- store output
|
| 270 |
+
idx_z = off_z
|
| 271 |
+
idx_t = off_t
|
| 272 |
+
idx_hq = off_hkv*G + off_g[:, None, None]
|
| 273 |
+
idx_m = off_m[None, :, None]
|
| 274 |
+
idx_d = offs_vd[None, None, :]
|
| 275 |
+
|
| 276 |
+
mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
|
| 277 |
+
acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
|
| 278 |
+
xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0
|
| 279 |
+
tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask)
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
# Utility triton funcs
|
| 283 |
+
@triton.jit
|
| 284 |
+
def get_offset_for_next_block(
|
| 285 |
+
loop_iter, col_indices, total_blocks,
|
| 286 |
+
SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
|
| 287 |
+
BLOCKS_ARE_CONTIGUOUS: tl.constexpr
|
| 288 |
+
):
|
| 289 |
+
if BLOCKS_ARE_CONTIGUOUS:
|
| 290 |
+
return BLOCK
|
| 291 |
+
cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
|
| 292 |
+
cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
|
| 293 |
+
next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
|
| 294 |
+
needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
|
| 295 |
+
jump_to_block = (next_block - cur_block) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
|
| 296 |
+
offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
|
| 297 |
+
return offset
|
| 298 |
+
|
| 299 |
+
@triton.jit
|
| 300 |
+
def get_bounded_indices(indices, max_len=None):
|
| 301 |
+
return indices % max_len if max_len is not None else indices
|
| 302 |
+
|
| 303 |
+
@triton.jit
|
| 304 |
+
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
|
| 305 |
+
if IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 306 |
+
return tl.load(block_ptr)
|
| 307 |
+
elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
|
| 308 |
+
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
|
| 309 |
+
elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 310 |
+
return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
|
| 311 |
+
else:
|
| 312 |
+
return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
|
| 313 |
+
|
| 314 |
+
@triton.jit
|
| 315 |
+
def load_checked_2d(
|
| 316 |
+
ptr,
|
| 317 |
+
offs_m,
|
| 318 |
+
offs_n,
|
| 319 |
+
stride_m,
|
| 320 |
+
stride_n,
|
| 321 |
+
IS_DIVISIBLE_M: tl.constexpr,
|
| 322 |
+
IS_DIVISIBLE_N: tl.constexpr,
|
| 323 |
+
M_LEN: tl.constexpr,
|
| 324 |
+
N_LEN: tl.constexpr,
|
| 325 |
+
):
|
| 326 |
+
# Calculate final pointer if strides are provided
|
| 327 |
+
if stride_m is not None and stride_n is not None:
|
| 328 |
+
ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
|
| 329 |
+
|
| 330 |
+
# Handle all masking cases
|
| 331 |
+
if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
|
| 332 |
+
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
|
| 333 |
+
elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
|
| 334 |
+
return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
|
| 335 |
+
elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
|
| 336 |
+
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
|
| 337 |
+
else: # Both divisible
|
| 338 |
+
return tl.load(ptr)
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
# Common helper functions
|
| 342 |
+
@triton.jit
|
| 343 |
+
def forward_block_mn(
|
| 344 |
+
arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
|
| 345 |
+
q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
|
| 346 |
+
# accumulated values
|
| 347 |
+
acc, l_i, m_i,
|
| 348 |
+
# Offsets
|
| 349 |
+
off_z, off_h, offs_m, offs_n,
|
| 350 |
+
# Offsets needed for TMA loads
|
| 351 |
+
kv_start,
|
| 352 |
+
kv_offset,
|
| 353 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 354 |
+
# Strides for K and V
|
| 355 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 356 |
+
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
|
| 357 |
+
|
| 358 |
+
):
|
| 359 |
+
# Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
|
| 360 |
+
PRESCALE_QK : tl.constexpr = False
|
| 361 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 362 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 363 |
+
WRITE_DQ : tl.constexpr = True
|
| 364 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 365 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 366 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 367 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 368 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 369 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 370 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 371 |
+
SPLIT_KV : tl.constexpr = 32
|
| 372 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 373 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 374 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 375 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 376 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 377 |
+
BLOCK_M : tl.constexpr = 512
|
| 378 |
+
SAFE_M_BOUNDARY : tl.constexpr = False
|
| 379 |
+
SAFE_N_BOUNDARY : tl.constexpr = True
|
| 380 |
+
BLOCK_N : tl.constexpr = 64
|
| 381 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 382 |
+
USE_TMA : tl.constexpr = False
|
| 383 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
# -- load k --
|
| 387 |
+
# NB reversed order since K is transposed
|
| 388 |
+
kv_base_offset = kv_start + kv_offset
|
| 389 |
+
|
| 390 |
+
# Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
|
| 391 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 392 |
+
offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
|
| 393 |
+
k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
|
| 394 |
+
|
| 395 |
+
k = tl.trans(k)
|
| 396 |
+
# -- compute qk ---
|
| 397 |
+
qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
|
| 398 |
+
if not PRESCALE_QK:
|
| 399 |
+
qk *= SM_SCALE
|
| 400 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
|
| 401 |
+
# If this is the last block of a non-divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
|
| 402 |
+
# which is larger than the actual number of elements. To avoid accessing memory out of bounds,
|
| 403 |
+
# we need to mask out the elements that are out of Q_LEN & KV_LEN.
|
| 404 |
+
m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
|
| 405 |
+
n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
|
| 406 |
+
|
| 407 |
+
tmp0 = (qk)
|
| 408 |
+
post_mod_scores = tmp0
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
if CHECK_BLOCK_BOUNDARY:
|
| 412 |
+
# Mask out the elements that are out of the KV_LEN for non divisible seqlen.
|
| 413 |
+
post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
|
| 414 |
+
|
| 415 |
+
if not IS_FULL_BLOCKS:
|
| 416 |
+
tmp1 = (m)
|
| 417 |
+
tmp2 = tl.full([1], 0, tl.int32)
|
| 418 |
+
tmp3 = tmp1 < tmp2
|
| 419 |
+
tmp4 = (n)
|
| 420 |
+
tmp5 = tmp4 <= tmp1
|
| 421 |
+
tmp6 = tmp3 & tmp5
|
| 422 |
+
tmp7 = tmp1 >= tmp2
|
| 423 |
+
tmp8 = tmp4 < tmp2
|
| 424 |
+
tmp9 = tmp7 & tmp8
|
| 425 |
+
tmp10 = tmp8 == 0
|
| 426 |
+
tmp11 = tmp7 & tmp10
|
| 427 |
+
tmp12 = tmp1 - tmp2
|
| 428 |
+
tmp13 = tl.full([1], 16, tl.int32)
|
| 429 |
+
tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
|
| 430 |
+
tmp15 = tmp4 - tmp2
|
| 431 |
+
tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
|
| 432 |
+
tmp17 = tmp14 == tmp16
|
| 433 |
+
tmp18 = tmp11 & tmp17
|
| 434 |
+
tmp19 = tmp9 | tmp18
|
| 435 |
+
tmp20 = tmp6 | tmp19
|
| 436 |
+
mask_mod_output = tmp20
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
if CHECK_BLOCK_BOUNDARY:
|
| 440 |
+
mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
|
| 441 |
+
# apply mask for partially unmasked blocks
|
| 442 |
+
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
|
| 443 |
+
|
| 444 |
+
if not PRESCALE_QK:
|
| 445 |
+
post_mod_scores *= RCP_LN2
|
| 446 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 447 |
+
|
| 448 |
+
# -- compute scaling constant ---
|
| 449 |
+
m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
|
| 450 |
+
if not ROWS_GUARANTEED_SAFE:
|
| 451 |
+
masked_out_rows = (m_ij == float("-inf"))
|
| 452 |
+
m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
|
| 453 |
+
else:
|
| 454 |
+
m_ij_masked = m_ij
|
| 455 |
+
|
| 456 |
+
alpha = tl.math.exp2(m_i - m_ij_masked)
|
| 457 |
+
p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
|
| 458 |
+
|
| 459 |
+
# NB: l_i update is pulled up here since it's a bit faster
|
| 460 |
+
# NB: For headdim=256, it's faster to move it back down to after the
|
| 461 |
+
# "m_i = m_ij" update.
|
| 462 |
+
l_i = l_i * alpha + tl.sum(p, 1)
|
| 463 |
+
# -- scale and update acc --
|
| 464 |
+
acc = acc * alpha[:, None]
|
| 465 |
+
# Calculate offsets for V loading - reuse kv_base_offset from K loading
|
| 466 |
+
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
| 467 |
+
v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
|
| 468 |
+
acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
|
| 469 |
+
|
| 470 |
+
# -- update m_i
|
| 471 |
+
m_i = m_ij
|
| 472 |
+
|
| 473 |
+
return acc, l_i, m_i
|
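# The m_i / l_i / alpha updates above are the standard online-softmax recurrence. Below is a
# minimal NumPy sketch over KV chunks, illustrative only: dense, no masking, and natural-base
# exp instead of the kernel's exp2/RCP_LN2 change of base.
import numpy as np

def _online_softmax_attention(q, k, v, chunk=64):
    """Streaming softmax(q @ k.T) @ v, one KV chunk at a time."""
    m_i = np.full(q.shape[0], -np.inf)          # running row max
    l_i = np.zeros(q.shape[0])                  # running sum of exp
    acc = np.zeros((q.shape[0], v.shape[1]))    # unnormalized output accumulator
    for s in range(0, k.shape[0], chunk):
        scores = q @ k[s:s + chunk].T
        m_ij = np.maximum(m_i, scores.max(axis=1))
        alpha = np.exp(m_i - m_ij)              # rescale old state to the new running max
        p = np.exp(scores - m_ij[:, None])
        l_i = l_i * alpha + p.sum(axis=1)
        acc = acc * alpha[:, None] + p @ v[s:s + chunk]
        m_i = m_ij
    return acc / l_i[:, None]

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    q, k, v = rng.normal(size=(8, 16)), rng.normal(size=(256, 16)), rng.normal(size=(256, 16))
    s = q @ k.T
    p = np.exp(s - s.max(1, keepdims=True))
    ref = (p / p.sum(1, keepdims=True)) @ v
    assert np.allclose(_online_softmax_attention(q, k, v), ref)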
| 474 |
+
|
| 475 |
+
@triton.jit
|
| 476 |
+
def forward_inner(
|
| 477 |
+
arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
|
| 478 |
+
q, K, V,
|
| 479 |
+
desc_k, desc_v, Q_LEN, KV_LEN,
|
| 480 |
+
# accumulated values
|
| 481 |
+
acc, l_i, m_i,
|
| 482 |
+
# Offsets used as inputs to score_mod & mask_mod
|
| 483 |
+
# of size [BLOCK_M, BLOCK_N] or scalar.
|
| 484 |
+
off_z, off_h, offs_m, offs_n,
|
| 485 |
+
# Offsets needed for TMA loads
|
| 486 |
+
kv_start,
|
| 487 |
+
# blocksparse data
|
| 488 |
+
kv_indices, kv_num_blocks,
|
| 489 |
+
# start kv and end kv block
|
| 490 |
+
block_n_start, block_n_end,
|
| 491 |
+
MATMUL_PRECISION,
|
| 492 |
+
# Strides for K and V
|
| 493 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 494 |
+
IS_FULL_BLOCKS,
|
| 495 |
+
):
|
| 496 |
+
# Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
|
| 497 |
+
PRESCALE_QK : tl.constexpr = False
|
| 498 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 499 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 500 |
+
WRITE_DQ : tl.constexpr = True
|
| 501 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 502 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 503 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 504 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 505 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 506 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 507 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 508 |
+
SPLIT_KV : tl.constexpr = 32
|
| 509 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 510 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 511 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 512 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 513 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 514 |
+
BLOCK_M : tl.constexpr = 512
|
| 515 |
+
SAFE_M_BOUNDARY : tl.constexpr = False
|
| 516 |
+
SAFE_N_BOUNDARY : tl.constexpr = True
|
| 517 |
+
BLOCK_N : tl.constexpr = 64
|
| 518 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 519 |
+
USE_TMA : tl.constexpr = False
|
| 520 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
|
| 524 |
+
RCP_LN2: tl.constexpr = 1.44269504
|
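# Change of base: exp(x) = 2**(x * RCP_LN2) since RCP_LN2 = 1/ln(2) = log2(e) ~= 1.44269504,
# and exp2 is cheaper than exp on GPU; scores are multiplied by RCP_LN2 once so that
# tl.math.exp2 can be used everywhere below.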
| 525 |
+
|
| 526 |
+
if PRESCALE_QK:
|
| 527 |
+
q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
|
| 528 |
+
|
| 529 |
+
kv_offset = 0
|
| 530 |
+
|
| 531 |
+
# loop over k, v and update accumulator until block_n_end
|
| 532 |
+
for start_n in range(block_n_start, block_n_end):
|
| 533 |
+
# Here IS_DIVISIBLE acts as the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
|
| 534 |
+
if IS_DIVISIBLE:
|
| 535 |
+
acc, l_i, m_i = forward_block_mn(
|
| 536 |
+
arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
|
| 537 |
+
q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
|
| 538 |
+
# accumulated values
|
| 539 |
+
acc, l_i, m_i,
|
| 540 |
+
# Offsets
|
| 541 |
+
off_z, off_h, offs_m, offs_n,
|
| 542 |
+
# Offsets needed for TMA loads
|
| 543 |
+
kv_start,
|
| 544 |
+
kv_offset,
|
| 545 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 546 |
+
# Strides for K and V
|
| 547 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 548 |
+
IS_FULL_BLOCKS,
|
| 549 |
+
)
|
| 550 |
+
else:
|
| 551 |
+
# Benchmarks show that even if we apply mod & mask to each block for a non-divisible seqlen,
|
| 552 |
+
# it's on par with or slightly faster than only applying them to the last block in fwd.
|
| 553 |
+
# However, we choose a different strategy for bwd, where we only apply mod & mask
|
| 554 |
+
# to the last block because it's a lot faster.
|
| 555 |
+
acc, l_i, m_i = forward_block_mn(
|
| 556 |
+
arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
|
| 557 |
+
q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
|
| 558 |
+
# accumulated values
|
| 559 |
+
acc, l_i, m_i,
|
| 560 |
+
# Offsets
|
| 561 |
+
off_z, off_h, offs_m, offs_n,
|
| 562 |
+
# Offsets needed for TMA loads
|
| 563 |
+
kv_start,
|
| 564 |
+
kv_offset,
|
| 565 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 566 |
+
# Strides for K and V
|
| 567 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 568 |
+
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
|
| 569 |
+
)
|
| 570 |
+
|
| 571 |
+
|
| 572 |
+
|
| 573 |
+
offset = get_offset_for_next_block(
|
| 574 |
+
start_n, kv_indices, kv_num_blocks,
|
| 575 |
+
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
|
| 576 |
+
)
|
| 577 |
+
|
| 578 |
+
offs_n = offs_n + offset
|
| 579 |
+
kv_offset += offset
|
| 580 |
+
|
| 581 |
+
|
| 582 |
+
return acc, l_i, m_i
|
progress/SpecForge/cache/compiled_kernels/4a/7887d45b1aa6124e232769adbe995f9cc2af0dd187cb9928540172d82c7b8631.best_config
ADDED
|
@@ -0,0 +1 @@
|
| 1 |
+
{"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "BZAXIZYYJGUVREZ5ANMEKVK5UU77TPVNED7QAB22EKNJIKFVURYA"}
|